mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 06:40:19 +00:00
Add support for cursor for AMQP consumers
Signed-off-by: dusan <borovcanindusan1@gmail.com>
This commit is contained in:
@@ -33,6 +33,9 @@ func PrefixedClientID(connID string) string {
|
||||
type QueueManager interface {
|
||||
Start(ctx context.Context) error
|
||||
Stop() error
|
||||
CreateQueue(ctx context.Context, config qtypes.QueueConfig) error
|
||||
UpdateQueue(ctx context.Context, config qtypes.QueueConfig) error
|
||||
GetQueue(ctx context.Context, queueName string) (*qtypes.QueueConfig, error)
|
||||
Publish(ctx context.Context, topic string, payload []byte, properties map[string]string) error
|
||||
Subscribe(ctx context.Context, queueName, pattern, clientID, groupID, proxyNodeID string) error
|
||||
SubscribeWithCursor(ctx context.Context, queueName, pattern, clientID, groupID, proxyNodeID string, cursor *qtypes.CursorOption) error
|
||||
|
||||
+302
-6
@@ -8,11 +8,15 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/fluxmq/amqp091/codec"
|
||||
qtypes "github.com/absmach/fluxmq/queue/types"
|
||||
"github.com/absmach/fluxmq/storage"
|
||||
"github.com/absmach/fluxmq/topics"
|
||||
)
|
||||
@@ -28,6 +32,12 @@ type consumer struct {
|
||||
exclusive bool
|
||||
}
|
||||
|
||||
type queueInfo struct {
|
||||
name string
|
||||
queueType string
|
||||
args map[string]interface{}
|
||||
}
|
||||
|
||||
// exchange represents a declared exchange (in-memory, per-connection for now).
|
||||
type exchange struct {
|
||||
name string
|
||||
@@ -52,7 +62,7 @@ type Channel struct {
|
||||
|
||||
// Exchange/queue/binding state (local to this connection for non-durable)
|
||||
exchanges map[string]*exchange
|
||||
queues map[string]bool // set of declared queue names
|
||||
queues map[string]*queueInfo // declared queues by name
|
||||
bindings []binding
|
||||
exchangeMu sync.RWMutex
|
||||
|
||||
@@ -115,7 +125,7 @@ func newChannel(c *Connection, id uint16) *Channel {
|
||||
conn: c,
|
||||
id: id,
|
||||
exchanges: make(map[string]*exchange),
|
||||
queues: make(map[string]bool),
|
||||
queues: make(map[string]*queueInfo),
|
||||
consumers: make(map[string]*consumer),
|
||||
unacked: make(map[uint64]*unackedDelivery),
|
||||
flow: true,
|
||||
@@ -335,6 +345,21 @@ func (ch *Channel) completePublish() {
|
||||
}
|
||||
}
|
||||
|
||||
// RabbitMQ-style stream queue publish: default exchange with routingKey == queue name.
|
||||
if exchangeName == "" && ch.isStreamQueue(routingKey) {
|
||||
qm := ch.conn.broker.getQueueManager()
|
||||
if qm != nil {
|
||||
queueTopic := "$queue/" + routingKey
|
||||
if err := qm.Publish(context.Background(), queueTopic, body, props); err != nil {
|
||||
ch.conn.logger.Error("queue publish failed", "queue", routingKey, "error", err)
|
||||
}
|
||||
if ch.confirmMode {
|
||||
ch.sendPublisherAck()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this targets a queue via exchange bindings
|
||||
isQueuePublish := false
|
||||
ch.exchangeMu.RLock()
|
||||
@@ -549,12 +574,26 @@ func (ch *Channel) sendDelivery(cons *consumer, topic string, payload []byte, pr
|
||||
return err
|
||||
}
|
||||
|
||||
headers := make(map[string]interface{})
|
||||
for k, v := range props {
|
||||
switch k {
|
||||
case "content-type", "content-encoding", "correlation-id", "reply-to", "message-id", "type":
|
||||
continue
|
||||
default:
|
||||
headers[k] = v
|
||||
}
|
||||
}
|
||||
if len(headers) == 0 {
|
||||
headers = nil
|
||||
}
|
||||
|
||||
properties := codec.BasicProperties{
|
||||
ContentType: props["content-type"],
|
||||
CorrelationID: props["correlation-id"],
|
||||
ReplyTo: props["reply-to"],
|
||||
MessageID: props["message-id"],
|
||||
Type: props["type"],
|
||||
Headers: headers,
|
||||
}
|
||||
|
||||
headerFrame, err := buildContentHeaderFrame(ch.id, uint64(len(payload)), properties)
|
||||
@@ -769,10 +808,45 @@ func (ch *Channel) handleQueueDeclare(m *codec.QueueDeclare) error {
|
||||
m.Queue = fmt.Sprintf("amq.gen-%s-%d", ch.conn.connID, seq)
|
||||
}
|
||||
|
||||
queueType := extractQueueType(m.Arguments)
|
||||
|
||||
ch.exchangeMu.Lock()
|
||||
ch.queues[m.Queue] = true
|
||||
ch.queues[m.Queue] = &queueInfo{
|
||||
name: m.Queue,
|
||||
queueType: queueType,
|
||||
args: m.Arguments,
|
||||
}
|
||||
ch.exchangeMu.Unlock()
|
||||
|
||||
if queueType == string(qtypes.QueueTypeStream) {
|
||||
qm := ch.conn.broker.getQueueManager()
|
||||
if qm != nil {
|
||||
queueTopicPattern := "$queue/" + m.Queue + "/#"
|
||||
var cfg qtypes.QueueConfig
|
||||
if m.Durable {
|
||||
cfg = qtypes.DefaultQueueConfig(m.Queue, queueTopicPattern)
|
||||
} else {
|
||||
cfg = qtypes.DefaultEphemeralQueueConfig(m.Queue, queueTopicPattern)
|
||||
}
|
||||
cfg.Type = qtypes.QueueTypeStream
|
||||
cfg.Retention = extractStreamRetention(m.Arguments)
|
||||
if err := qm.CreateQueue(context.Background(), cfg); err != nil {
|
||||
// If it already exists, attempt to update retention/type only.
|
||||
if existing, err := qm.GetQueue(context.Background(), m.Queue); err == nil && existing != nil {
|
||||
existing.Type = qtypes.QueueTypeStream
|
||||
existing.Retention = cfg.Retention
|
||||
if !m.Durable {
|
||||
existing.Durable = false
|
||||
if existing.ExpiresAfter == 0 {
|
||||
existing.ExpiresAfter = 5 * time.Minute
|
||||
}
|
||||
}
|
||||
_ = qm.UpdateQueue(context.Background(), *existing)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Auto-bind queue to default exchange with routing key = queue name
|
||||
ch.exchangeMu.Lock()
|
||||
ch.bindings = append(ch.bindings, binding{
|
||||
@@ -877,7 +951,19 @@ func (ch *Channel) handleBasicConsume(m *codec.BasicConsume) error {
|
||||
if isQueue {
|
||||
queueName, pattern = parseQueueFilter(queueFilter)
|
||||
}
|
||||
|
||||
queueInfo := ch.getQueueInfo(queueFilter)
|
||||
streamCursor, hasStreamOffset := extractStreamOffset(m.Arguments)
|
||||
isStream := (queueInfo != nil && queueInfo.queueType == string(qtypes.QueueTypeStream)) || hasStreamOffset
|
||||
if isStream && !isQueue {
|
||||
queueName = queueFilter
|
||||
pattern = ""
|
||||
}
|
||||
|
||||
groupID := extractConsumerGroup(m.Arguments)
|
||||
if isStream && groupID == "" {
|
||||
groupID = tag
|
||||
}
|
||||
|
||||
ch.consumersMu.Lock()
|
||||
if _, exists := ch.consumers[tag]; exists {
|
||||
@@ -900,15 +986,25 @@ func (ch *Channel) handleBasicConsume(m *codec.BasicConsume) error {
|
||||
|
||||
// Subscribe to the queue via queue manager
|
||||
qm := ch.conn.broker.getQueueManager()
|
||||
if qm != nil && isQueue && queueName != "" {
|
||||
if qm != nil && (isQueue || isStream) && queueName != "" {
|
||||
clientID := PrefixedClientID(ch.conn.connID)
|
||||
if err := qm.Subscribe(context.Background(), queueName, pattern, clientID, groupID, ""); err != nil {
|
||||
subGroupID := groupID
|
||||
if isStream {
|
||||
cursor := streamCursor
|
||||
if cursor == nil {
|
||||
cursor = &qtypes.CursorOption{Position: qtypes.CursorLatest}
|
||||
}
|
||||
cursor.Mode = qtypes.GroupModeStream
|
||||
if err := qm.SubscribeWithCursor(context.Background(), queueName, pattern, clientID, subGroupID, "", cursor); err != nil {
|
||||
ch.conn.logger.Error("queue subscribe with cursor failed", "queue", queueName, "error", err)
|
||||
}
|
||||
} else if err := qm.Subscribe(context.Background(), queueName, pattern, clientID, subGroupID, ""); err != nil {
|
||||
ch.conn.logger.Error("queue subscribe failed", "queue", queueName, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe via the topic router for pub/sub delivery (non-queue topics).
|
||||
if !isQueue {
|
||||
if !isQueue && !isStream {
|
||||
connID := ch.conn.connID
|
||||
ch.conn.broker.router.Subscribe(connID, queueFilter, 1, storage.SubscribeOptions{})
|
||||
}
|
||||
@@ -1172,3 +1268,203 @@ func extractConsumerGroup(args map[string]interface{}) string {
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *Channel) getQueueInfo(name string) *queueInfo {
|
||||
ch.exchangeMu.RLock()
|
||||
defer ch.exchangeMu.RUnlock()
|
||||
return ch.queues[name]
|
||||
}
|
||||
|
||||
func (ch *Channel) isStreamQueue(name string) bool {
|
||||
if info := ch.getQueueInfo(name); info != nil && info.queueType == string(qtypes.QueueTypeStream) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(name, "$queue/") {
|
||||
base := strings.TrimPrefix(name, "$queue/")
|
||||
if info := ch.getQueueInfo(base); info != nil && info.queueType == string(qtypes.QueueTypeStream) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractQueueType(args map[string]interface{}) string {
|
||||
if len(args) == 0 {
|
||||
return string(qtypes.QueueTypeClassic)
|
||||
}
|
||||
val, ok := args["x-queue-type"]
|
||||
if !ok {
|
||||
return string(qtypes.QueueTypeClassic)
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
return string(qtypes.QueueTypeClassic)
|
||||
}
|
||||
return strings.ToLower(v)
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return string(qtypes.QueueTypeClassic)
|
||||
}
|
||||
return strings.ToLower(string(v))
|
||||
default:
|
||||
return string(qtypes.QueueTypeClassic)
|
||||
}
|
||||
}
|
||||
|
||||
func extractStreamRetention(args map[string]interface{}) qtypes.RetentionPolicy {
|
||||
var policy qtypes.RetentionPolicy
|
||||
if len(args) == 0 {
|
||||
return policy
|
||||
}
|
||||
|
||||
if val, ok := args["x-max-age"]; ok {
|
||||
if d, ok := parseDurationArg(val); ok {
|
||||
policy.RetentionTime = d
|
||||
}
|
||||
}
|
||||
if val, ok := args["x-max-length-bytes"]; ok {
|
||||
if n, ok := parseInt64Arg(val); ok {
|
||||
policy.RetentionBytes = n
|
||||
}
|
||||
}
|
||||
if val, ok := args["x-max-length"]; ok {
|
||||
if n, ok := parseInt64Arg(val); ok {
|
||||
policy.RetentionMessages = n
|
||||
}
|
||||
}
|
||||
|
||||
return policy
|
||||
}
|
||||
|
||||
func extractStreamOffset(args map[string]interface{}) (*qtypes.CursorOption, bool) {
|
||||
if len(args) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
val, ok := args["x-stream-offset"]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
return parseStreamOffsetString(v)
|
||||
case []byte:
|
||||
return parseStreamOffsetString(string(v))
|
||||
case int:
|
||||
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: uint64(v)}, true
|
||||
case int64:
|
||||
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: uint64(v)}, true
|
||||
case uint64:
|
||||
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: v}, true
|
||||
case uint32:
|
||||
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: uint64(v)}, true
|
||||
case time.Time:
|
||||
return &qtypes.CursorOption{Position: qtypes.CursorTimestamp, Timestamp: v}, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func parseStreamOffsetString(val string) (*qtypes.CursorOption, bool) {
|
||||
if val == "" {
|
||||
return nil, false
|
||||
}
|
||||
v := strings.ToLower(strings.TrimSpace(val))
|
||||
switch v {
|
||||
case "first":
|
||||
return &qtypes.CursorOption{Position: qtypes.CursorEarliest}, true
|
||||
case "last", "next":
|
||||
return &qtypes.CursorOption{Position: qtypes.CursorLatest}, true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(v, "offset=") {
|
||||
v = strings.TrimPrefix(v, "offset=")
|
||||
}
|
||||
if strings.HasPrefix(v, "timestamp=") {
|
||||
raw := strings.TrimPrefix(v, "timestamp=")
|
||||
if ts, ok := parseUnixTimestamp(raw); ok {
|
||||
return &qtypes.CursorOption{Position: qtypes.CursorTimestamp, Timestamp: ts}, true
|
||||
}
|
||||
}
|
||||
|
||||
if off, err := strconv.ParseUint(v, 10, 64); err == nil {
|
||||
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: off}, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func parseUnixTimestamp(raw string) (time.Time, bool) {
|
||||
if raw == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
val, err := strconv.ParseInt(raw, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
if val > 1e12 {
|
||||
return time.UnixMilli(val), true
|
||||
}
|
||||
return time.Unix(val, 0), true
|
||||
}
|
||||
|
||||
func parseDurationArg(val any) (time.Duration, bool) {
|
||||
switch v := val.(type) {
|
||||
case time.Duration:
|
||||
return v, true
|
||||
case string:
|
||||
trimmed := strings.TrimSpace(v)
|
||||
if trimmed == "" {
|
||||
return 0, false
|
||||
}
|
||||
if d, err := time.ParseDuration(trimmed); err == nil {
|
||||
return d, true
|
||||
}
|
||||
upper := strings.ToUpper(trimmed)
|
||||
if strings.HasSuffix(upper, "D") {
|
||||
num := strings.TrimSuffix(upper, "D")
|
||||
if f, err := strconv.ParseFloat(num, 64); err == nil {
|
||||
return time.Duration(f * float64(24*time.Hour)), true
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(upper, "W") {
|
||||
num := strings.TrimSuffix(upper, "W")
|
||||
if f, err := strconv.ParseFloat(num, 64); err == nil {
|
||||
return time.Duration(f * float64(7*24*time.Hour)), true
|
||||
}
|
||||
}
|
||||
case int:
|
||||
return time.Duration(v) * time.Second, true
|
||||
case int64:
|
||||
return time.Duration(v) * time.Second, true
|
||||
case uint64:
|
||||
return time.Duration(v) * time.Second, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func parseInt64Arg(val any) (int64, bool) {
|
||||
switch v := val.(type) {
|
||||
case int64:
|
||||
return v, true
|
||||
case int:
|
||||
return int64(v), true
|
||||
case uint64:
|
||||
if v > math.MaxInt64 {
|
||||
return 0, false
|
||||
}
|
||||
return int64(v), true
|
||||
case uint32:
|
||||
return int64(v), true
|
||||
case string:
|
||||
if n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
|
||||
return n, true
|
||||
}
|
||||
case []byte:
|
||||
if n, err := strconv.ParseInt(strings.TrimSpace(string(v)), 10, 64); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/fluxmq/amqp091/codec"
|
||||
qtypes "github.com/absmach/fluxmq/queue/types"
|
||||
)
|
||||
|
||||
func newTestChannel(t *testing.T) (*Channel, *bytes.Buffer) {
|
||||
@@ -262,3 +263,25 @@ func TestExchangeNotFoundOnPublish(t *testing.T) {
|
||||
t.Fatalf("expected NotFound, got %d", closeMsg.ReplyCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStreamOffsetString(t *testing.T) {
|
||||
first, ok := parseStreamOffsetString("first")
|
||||
if !ok || first.Position != qtypes.CursorEarliest {
|
||||
t.Fatalf("expected earliest, got %+v", first)
|
||||
}
|
||||
|
||||
last, ok := parseStreamOffsetString("last")
|
||||
if !ok || last.Position != qtypes.CursorLatest {
|
||||
t.Fatalf("expected latest, got %+v", last)
|
||||
}
|
||||
|
||||
offset, ok := parseStreamOffsetString("offset=42")
|
||||
if !ok || offset.Position != qtypes.CursorOffset || offset.Offset != 42 {
|
||||
t.Fatalf("expected offset 42, got %+v", offset)
|
||||
}
|
||||
|
||||
ts, ok := parseStreamOffsetString("timestamp=1700000000")
|
||||
if !ok || ts.Position != qtypes.CursorTimestamp || ts.Timestamp.IsZero() {
|
||||
t.Fatalf("expected timestamp, got %+v", ts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package amqp091
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -18,6 +19,31 @@ type QueuePublishOptions struct {
|
||||
Properties map[string]string // Optional message properties
|
||||
}
|
||||
|
||||
// StreamQueueOptions configures a stream queue declaration.
|
||||
type StreamQueueOptions struct {
|
||||
Name string
|
||||
Durable bool
|
||||
AutoDelete bool
|
||||
Exclusive bool
|
||||
NoWait bool
|
||||
MaxAge string // e.g. "7D", "1h"
|
||||
MaxLengthBytes int64
|
||||
MaxLengthMessages int64
|
||||
}
|
||||
|
||||
// StreamConsumeOptions configures a stream queue subscription.
|
||||
type StreamConsumeOptions struct {
|
||||
QueueName string
|
||||
ConsumerGroup string
|
||||
Offset string // "first", "last", "next", "offset=123", "timestamp=..."
|
||||
AutoAck bool
|
||||
Exclusive bool
|
||||
NoLocal bool
|
||||
NoWait bool
|
||||
ConsumerTag string
|
||||
Arguments amqp091.Table
|
||||
}
|
||||
|
||||
// QueueMessageHandler is called when a queue message is received.
|
||||
type QueueMessageHandler func(msg *QueueMessage)
|
||||
|
||||
@@ -49,6 +75,34 @@ func (qm *QueueMessage) Reject() error {
|
||||
})
|
||||
}
|
||||
|
||||
// StreamOffset returns the stream offset if present.
|
||||
func (qm *QueueMessage) StreamOffset() (uint64, bool) {
|
||||
return headerUint64(qm.Headers, "x-stream-offset")
|
||||
}
|
||||
|
||||
// StreamTimestamp returns the stream timestamp (unix millis) if present.
|
||||
func (qm *QueueMessage) StreamTimestamp() (int64, bool) {
|
||||
if v, ok := headerInt64(qm.Headers, "x-stream-timestamp"); ok {
|
||||
return v, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// WorkAcked reports whether the primary work group has acknowledged this offset.
|
||||
func (qm *QueueMessage) WorkAcked() (bool, bool) {
|
||||
return headerBool(qm.Headers, "x-work-acked")
|
||||
}
|
||||
|
||||
// WorkCommittedOffset returns the primary group's committed offset if present.
|
||||
func (qm *QueueMessage) WorkCommittedOffset() (uint64, bool) {
|
||||
return headerUint64(qm.Headers, "x-work-committed-offset")
|
||||
}
|
||||
|
||||
// WorkGroup returns the primary work group name if present.
|
||||
func (qm *QueueMessage) WorkGroup() (string, bool) {
|
||||
return headerString(qm.Headers, "x-work-group")
|
||||
}
|
||||
|
||||
func (qm *QueueMessage) withChannelLock(fn func() error) error {
|
||||
if qm.client == nil {
|
||||
return fn()
|
||||
@@ -111,6 +165,66 @@ func (c *Client) PublishToQueueWithOptions(opts *QueuePublishOptions) error {
|
||||
return c.publish("", queueTopic, publishing, false, false)
|
||||
}
|
||||
|
||||
// PublishToStream publishes a message to a stream queue (RabbitMQ-style queue name).
|
||||
func (c *Client) PublishToStream(queueName string, payload []byte, props map[string]string) error {
|
||||
if !c.connected.Load() {
|
||||
return ErrNotConnected
|
||||
}
|
||||
if queueName == "" {
|
||||
return ErrInvalidQueueName
|
||||
}
|
||||
|
||||
publishing := amqp091.Publishing{
|
||||
Timestamp: time.Now(),
|
||||
Body: payload,
|
||||
}
|
||||
applyProperties(&publishing, props)
|
||||
return c.publish("", queueName, publishing, false, false)
|
||||
}
|
||||
|
||||
// DeclareStreamQueue declares a stream queue with retention settings.
|
||||
func (c *Client) DeclareStreamQueue(opts *StreamQueueOptions) (string, error) {
|
||||
if !c.connected.Load() {
|
||||
return "", ErrNotConnected
|
||||
}
|
||||
if opts == nil {
|
||||
return "", ErrInvalidQueueName
|
||||
}
|
||||
|
||||
args := amqp091.Table{
|
||||
"x-queue-type": "stream",
|
||||
}
|
||||
if opts.MaxAge != "" {
|
||||
args["x-max-age"] = opts.MaxAge
|
||||
}
|
||||
if opts.MaxLengthBytes > 0 {
|
||||
args["x-max-length-bytes"] = opts.MaxLengthBytes
|
||||
}
|
||||
if opts.MaxLengthMessages > 0 {
|
||||
args["x-max-length"] = opts.MaxLengthMessages
|
||||
}
|
||||
|
||||
ch, err := c.channel()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
c.chMu.Lock()
|
||||
q, err := ch.QueueDeclare(
|
||||
opts.Name,
|
||||
opts.Durable,
|
||||
opts.AutoDelete,
|
||||
opts.Exclusive,
|
||||
opts.NoWait,
|
||||
args,
|
||||
)
|
||||
c.chMu.Unlock()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return q.Name, nil
|
||||
}
|
||||
|
||||
// SubscribeToQueue subscribes to a durable queue with a consumer group.
|
||||
// The queueName should NOT include the "$queue/" prefix - it will be added automatically.
|
||||
// The handler will be called for each message received from the queue.
|
||||
@@ -154,6 +268,51 @@ func (c *Client) SubscribeToQueue(queueName, consumerGroup string, handler Queue
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubscribeToStream subscribes to a stream queue with cursor control.
|
||||
func (c *Client) SubscribeToStream(opts *StreamConsumeOptions, handler QueueMessageHandler) error {
|
||||
if !c.connected.Load() {
|
||||
return ErrNotConnected
|
||||
}
|
||||
if opts == nil || opts.QueueName == "" {
|
||||
return ErrInvalidQueueName
|
||||
}
|
||||
if handler == nil {
|
||||
return ErrNilHandler
|
||||
}
|
||||
|
||||
queueName := opts.QueueName
|
||||
c.subsMu.Lock()
|
||||
if _, exists := c.queueSubs[queueName]; exists {
|
||||
c.subsMu.Unlock()
|
||||
return ErrAlreadySubscribed
|
||||
}
|
||||
|
||||
consumerTag := opts.ConsumerTag
|
||||
if consumerTag == "" {
|
||||
consumerTag = "ctag-" + strings.ReplaceAll(queueName, "/", "-") + "-" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
|
||||
sub := &queueSubscription{
|
||||
queueName: queueName,
|
||||
queueTopic: queueName,
|
||||
consumerGroup: opts.ConsumerGroup,
|
||||
consumerTag: consumerTag,
|
||||
handler: handler,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
c.queueSubs[queueName] = sub
|
||||
c.subsMu.Unlock()
|
||||
|
||||
if err := c.subscribeStream(sub, opts); err != nil {
|
||||
c.subsMu.Lock()
|
||||
delete(c.queueSubs, queueName)
|
||||
c.subsMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnsubscribeFromQueue unsubscribes from a durable queue.
|
||||
// The queueName should NOT include the "$queue/" prefix - it will be added automatically.
|
||||
func (c *Client) UnsubscribeFromQueue(queueName string) error {
|
||||
@@ -182,6 +341,31 @@ func (c *Client) UnsubscribeFromQueue(queueName string) error {
|
||||
return ch.Cancel(sub.consumerTag, false)
|
||||
}
|
||||
|
||||
// UnsubscribeFromStream unsubscribes from a stream queue.
|
||||
func (c *Client) UnsubscribeFromStream(queueName string) error {
|
||||
c.subsMu.Lock()
|
||||
sub, ok := c.queueSubs[queueName]
|
||||
if ok {
|
||||
delete(c.queueSubs, queueName)
|
||||
}
|
||||
c.subsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
sub.close()
|
||||
|
||||
ch, err := c.channel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.chMu.Lock()
|
||||
defer c.chMu.Unlock()
|
||||
return ch.Cancel(sub.consumerTag, false)
|
||||
}
|
||||
|
||||
func (c *Client) subscribeQueue(sub *queueSubscription) error {
|
||||
ch, err := c.channel()
|
||||
if err != nil {
|
||||
@@ -229,6 +413,72 @@ func (c *Client) subscribeQueue(sub *queueSubscription) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) subscribeStream(sub *queueSubscription, opts *StreamConsumeOptions) error {
|
||||
ch, err := c.channel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
args := amqp091.Table{}
|
||||
if opts != nil {
|
||||
if opts.ConsumerGroup != "" {
|
||||
args["x-consumer-group"] = opts.ConsumerGroup
|
||||
}
|
||||
if opts.Offset != "" {
|
||||
args["x-stream-offset"] = opts.Offset
|
||||
}
|
||||
for k, v := range opts.Arguments {
|
||||
args[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
autoAck := false
|
||||
exclusive := false
|
||||
noLocal := false
|
||||
noWait := false
|
||||
if opts != nil {
|
||||
autoAck = opts.AutoAck
|
||||
exclusive = opts.Exclusive
|
||||
noLocal = opts.NoLocal
|
||||
noWait = opts.NoWait
|
||||
}
|
||||
|
||||
c.chMu.Lock()
|
||||
deliveries, err := ch.Consume(
|
||||
sub.queueTopic,
|
||||
sub.consumerTag,
|
||||
autoAck,
|
||||
exclusive,
|
||||
noLocal,
|
||||
noWait,
|
||||
args,
|
||||
)
|
||||
c.chMu.Unlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-sub.done:
|
||||
return
|
||||
case d, ok := <-deliveries:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sub.handler(&QueueMessage{
|
||||
Delivery: d,
|
||||
queueName: sub.queueName,
|
||||
client: c,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeQueueTopic(queueName string) string {
|
||||
if strings.HasPrefix(queueName, "$queue/") {
|
||||
return queueName
|
||||
@@ -267,3 +517,99 @@ func applyProperties(p *amqp091.Publishing, props map[string]string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func headerUint64(headers amqp091.Table, key string) (uint64, bool) {
|
||||
if headers == nil {
|
||||
return 0, false
|
||||
}
|
||||
val, ok := headers[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case uint64:
|
||||
return v, true
|
||||
case uint32:
|
||||
return uint64(v), true
|
||||
case int64:
|
||||
if v < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return uint64(v), true
|
||||
case int:
|
||||
if v < 0 {
|
||||
return 0, false
|
||||
}
|
||||
return uint64(v), true
|
||||
case string:
|
||||
if n, err := strconv.ParseUint(v, 10, 64); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func headerInt64(headers amqp091.Table, key string) (int64, bool) {
|
||||
if headers == nil {
|
||||
return 0, false
|
||||
}
|
||||
val, ok := headers[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case int64:
|
||||
return v, true
|
||||
case int:
|
||||
return int64(v), true
|
||||
case uint64:
|
||||
if v > math.MaxInt64 {
|
||||
return 0, false
|
||||
}
|
||||
return int64(v), true
|
||||
case string:
|
||||
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func headerBool(headers amqp091.Table, key string) (bool, bool) {
|
||||
if headers == nil {
|
||||
return false, false
|
||||
}
|
||||
val, ok := headers[key]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
return v, true
|
||||
case string:
|
||||
if v == "true" {
|
||||
return true, true
|
||||
}
|
||||
if v == "false" {
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func headerString(headers amqp091.Table, key string) (string, bool) {
|
||||
if headers == nil {
|
||||
return "", false
|
||||
}
|
||||
val, ok := headers[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
return v, true
|
||||
case []byte:
|
||||
return string(v), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package amqp091
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
amqp091 "github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
func TestQueueMessageStreamMetadata(t *testing.T) {
|
||||
msg := &QueueMessage{
|
||||
Delivery: amqp091.Delivery{
|
||||
Headers: amqp091.Table{
|
||||
"x-stream-offset": "42",
|
||||
"x-stream-timestamp": "1700000000000",
|
||||
"x-work-acked": "true",
|
||||
"x-work-committed-offset": "100",
|
||||
"x-work-group": "workers",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if off, ok := msg.StreamOffset(); !ok || off != 42 {
|
||||
t.Fatalf("expected stream offset 42, got %d (ok=%v)", off, ok)
|
||||
}
|
||||
if ts, ok := msg.StreamTimestamp(); !ok || ts != 1700000000000 {
|
||||
t.Fatalf("expected stream timestamp, got %d (ok=%v)", ts, ok)
|
||||
}
|
||||
if acked, ok := msg.WorkAcked(); !ok || !acked {
|
||||
t.Fatalf("expected work acked true, got %v (ok=%v)", acked, ok)
|
||||
}
|
||||
if committed, ok := msg.WorkCommittedOffset(); !ok || committed != 100 {
|
||||
t.Fatalf("expected committed offset 100, got %d (ok=%v)", committed, ok)
|
||||
}
|
||||
if group, ok := msg.WorkGroup(); !ok || group != "workers" {
|
||||
t.Fatalf("expected work group workers, got %q (ok=%v)", group, ok)
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ type QueueConsumerInfo struct {
|
||||
ConsumerID string // Consumer identifier (usually client ID)
|
||||
ClientID string // MQTT client ID
|
||||
Pattern string // Subscription pattern within the queue
|
||||
Mode string // Consumer group mode (queue or stream)
|
||||
ProxyNodeID string // Node where the consumer is connected
|
||||
RegisteredAt time.Time
|
||||
}
|
||||
|
||||
+10
-3
@@ -19,10 +19,10 @@ import (
|
||||
amqp091broker "github.com/absmach/fluxmq/amqp091/broker"
|
||||
"github.com/absmach/fluxmq/broker/webhook"
|
||||
"github.com/absmach/fluxmq/cluster"
|
||||
clusterv1 "github.com/absmach/fluxmq/pkg/proto/cluster/v1"
|
||||
"github.com/absmach/fluxmq/config"
|
||||
logStorage "github.com/absmach/fluxmq/logstorage"
|
||||
"github.com/absmach/fluxmq/mqtt/broker"
|
||||
clusterv1 "github.com/absmach/fluxmq/pkg/proto/cluster/v1"
|
||||
mqtttls "github.com/absmach/fluxmq/pkg/tls"
|
||||
"github.com/absmach/fluxmq/queue"
|
||||
"github.com/absmach/fluxmq/queue/raft"
|
||||
@@ -47,8 +47,8 @@ import (
|
||||
|
||||
// messageDispatcher routes cluster-delivered messages to the appropriate protocol broker.
|
||||
type messageDispatcher struct {
|
||||
mqtt cluster.MessageHandler
|
||||
amqp *amqpbroker.Broker
|
||||
mqtt cluster.MessageHandler
|
||||
amqp *amqpbroker.Broker
|
||||
amqp091 *amqp091broker.Broker
|
||||
}
|
||||
|
||||
@@ -339,6 +339,8 @@ func main() {
|
||||
Name: qc.Name,
|
||||
Topics: qc.Topics,
|
||||
Reserved: qc.Reserved,
|
||||
Type: queueTypes.QueueType(qc.Type),
|
||||
PrimaryGroup: qc.PrimaryGroup,
|
||||
MaxMessageSize: qc.Limits.MaxMessageSize,
|
||||
MaxDepth: qc.Limits.MaxDepth,
|
||||
MessageTTL: qc.Limits.MessageTTL,
|
||||
@@ -348,6 +350,11 @@ func main() {
|
||||
Multiplier: qc.Retry.Multiplier,
|
||||
DLQEnabled: qc.DLQ.Enabled,
|
||||
DLQTopic: qc.DLQ.Topic,
|
||||
Retention: queueTypes.RetentionPolicy{
|
||||
RetentionTime: qc.Retention.MaxAge,
|
||||
RetentionBytes: qc.Retention.MaxLengthBytes,
|
||||
RetentionMessages: qc.Retention.MaxLengthMessages,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
+25
-15
@@ -15,25 +15,28 @@ import (
|
||||
|
||||
// Config holds all configuration for the MQTT broker.
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Broker BrokerConfig `yaml:"broker"`
|
||||
Session SessionConfig `yaml:"session"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
Storage StorageConfig `yaml:"storage"`
|
||||
Cluster ClusterConfig `yaml:"cluster"`
|
||||
Webhook WebhookConfig `yaml:"webhook"`
|
||||
RateLimit RateLimitConfig `yaml:"ratelimit"`
|
||||
Queues []QueueConfig `yaml:"queues"`
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Broker BrokerConfig `yaml:"broker"`
|
||||
Session SessionConfig `yaml:"session"`
|
||||
Log LogConfig `yaml:"log"`
|
||||
Storage StorageConfig `yaml:"storage"`
|
||||
Cluster ClusterConfig `yaml:"cluster"`
|
||||
Webhook WebhookConfig `yaml:"webhook"`
|
||||
RateLimit RateLimitConfig `yaml:"ratelimit"`
|
||||
Queues []QueueConfig `yaml:"queues"`
|
||||
}
|
||||
|
||||
// QueueConfig defines configuration for a persistent queue.
|
||||
type QueueConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Topics []string `yaml:"topics"`
|
||||
Reserved bool `yaml:"reserved"`
|
||||
Limits QueueLimits `yaml:"limits"`
|
||||
Retry QueueRetry `yaml:"retry"`
|
||||
DLQ QueueDLQ `yaml:"dlq"`
|
||||
Name string `yaml:"name"`
|
||||
Topics []string `yaml:"topics"`
|
||||
Reserved bool `yaml:"reserved"`
|
||||
Type string `yaml:"type"`
|
||||
PrimaryGroup string `yaml:"primary_group"`
|
||||
Retention QueueRetention `yaml:"retention"`
|
||||
Limits QueueLimits `yaml:"limits"`
|
||||
Retry QueueRetry `yaml:"retry"`
|
||||
DLQ QueueDLQ `yaml:"dlq"`
|
||||
}
|
||||
|
||||
// QueueLimits defines resource limits for a queue.
|
||||
@@ -57,6 +60,13 @@ type QueueDLQ struct {
|
||||
Topic string `yaml:"topic"`
|
||||
}
|
||||
|
||||
// QueueRetention defines retention policy for a queue.
|
||||
type QueueRetention struct {
|
||||
MaxAge time.Duration `yaml:"max_age"`
|
||||
MaxLengthBytes int64 `yaml:"max_length_bytes"`
|
||||
MaxLengthMessages int64 `yaml:"max_length_messages"`
|
||||
}
|
||||
|
||||
// RateLimitConfig holds rate limiting configuration.
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
@@ -520,6 +520,54 @@ func main() {
|
||||
- `SubscribeToQueue` passes the consumer group via `x-consumer-group` on `basic.consume`.
|
||||
- `Ack`, `Nack`, and `Reject` map to `basic.ack`, `basic.nack`, and `basic.reject`.
|
||||
|
||||
### Stream Queues (RabbitMQ-Compatible)
|
||||
|
||||
Stream queues provide log-style consumption with cursor offsets.
|
||||
Stream queue names follow RabbitMQ conventions (no `$queue/` prefix).
|
||||
Supported offsets: `first`, `last`, `next`, `offset=<n>`, `timestamp=<unix>`.
|
||||
|
||||
```go
|
||||
// Declare a stream queue
|
||||
qName, err := c.DeclareStreamQueue(&amqp091.StreamQueueOptions{
|
||||
Name: "events",
|
||||
Durable: true,
|
||||
MaxAge: "7D",
|
||||
MaxLengthBytes: 10 * 1024 * 1024 * 1024,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf("stream queue: %s", qName)
|
||||
|
||||
// Consume from the beginning
|
||||
err = c.SubscribeToStream(&amqp091.StreamConsumeOptions{
|
||||
QueueName: "events",
|
||||
Offset: "first",
|
||||
}, func(msg *amqp091.QueueMessage) {
|
||||
if off, ok := msg.StreamOffset(); ok {
|
||||
log.Printf("offset=%d payload=%s", off, string(msg.Body))
|
||||
}
|
||||
_ = msg.Ack()
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Publish to the stream queue (RabbitMQ-style)
|
||||
if err := c.PublishToStream("events", []byte("hello"), nil); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
```
|
||||
|
||||
Stream deliveries include:
|
||||
- `x-stream-offset`
|
||||
- `x-stream-timestamp`
|
||||
- `x-work-acked` / `x-work-committed-offset`
|
||||
|
||||
The `x-work-*` fields report the configured primary work group’s committed offset.
|
||||
Convenience accessors are available on `QueueMessage`:
|
||||
`StreamOffset()`, `StreamTimestamp()`, `WorkAcked()`, `WorkCommittedOffset()`, `WorkGroup()`.
|
||||
|
||||
### Pub/Sub
|
||||
|
||||
```go
|
||||
|
||||
@@ -525,6 +525,12 @@ queues:
|
||||
- "orders/#"
|
||||
- "$queue/orders/#" # allow explicit queue publish
|
||||
reserved: false
|
||||
type: "stream" # classic|stream (optional)
|
||||
primary_group: "workers"
|
||||
retention:
|
||||
max_age: "168h"
|
||||
max_length_bytes: 10737418240
|
||||
max_length_messages: 1000000
|
||||
limits:
|
||||
max_message_size: 1048576
|
||||
max_depth: 100000
|
||||
@@ -554,6 +560,23 @@ queue, include `$queue/<name>/#` in the list.
|
||||
|
||||
Marks a queue as system-reserved (cannot be deleted via management APIs).
|
||||
|
||||
### type
|
||||
|
||||
Queue behavior mode:
|
||||
- `classic` (default): work-queue semantics (acks drive committed offset)
|
||||
- `stream`: append-only log semantics with cursor-based consumption
|
||||
|
||||
### primary_group
|
||||
|
||||
Consumer group name used to compute delivery status for stream consumers.
|
||||
|
||||
### retention
|
||||
|
||||
Retention policy for stream-style access:
|
||||
- `max_age`: time-based retention window
|
||||
- `max_length_bytes`: size-based retention window
|
||||
- `max_length_messages`: message-count window
|
||||
|
||||
### limits / retry / dlq
|
||||
|
||||
Per-queue limits, retry policy, and dead-letter behavior. If `dlq.topic` is
|
||||
|
||||
+65
-4
@@ -35,6 +35,7 @@ Queues provide durable, at-least-once delivery across protocols:
|
||||
- Redelivery via visibility timeouts and work stealing
|
||||
- DLQ handler exists but is not wired into the main delivery path
|
||||
- FIFO order per queue and per consumer group (single cursor)
|
||||
- Stream queues (RabbitMQ-compatible) for event-log consumption with cursor offsets
|
||||
|
||||
Queues are integrated with MQTT and AMQP:
|
||||
- MQTT uses `$queue/<name>/...` topics
|
||||
@@ -78,6 +79,14 @@ $queue/tasks/image-processing/$reject → Reject (DLQ wiring is planned)
|
||||
$dlq/{queue-name} → Dead-letter queue
|
||||
```
|
||||
|
||||
Stream queues use RabbitMQ-compatible queue names and arguments:
|
||||
|
||||
- Declare with `x-queue-type=stream`
|
||||
- Consume with `x-stream-offset`
|
||||
- Retention via `x-max-age`, `x-max-length-bytes`, `x-max-length`
|
||||
|
||||
The `$queue/<name>` prefix remains supported for legacy queue clients.
|
||||
|
||||
**Acknowledgment requirements**:
|
||||
- `message-id` and `group-id` must be provided in message properties.
|
||||
- For MQTT, these are MQTT v5 User Properties.
|
||||
@@ -90,6 +99,13 @@ Queue deliveries include properties:
|
||||
- `queue`
|
||||
- `offset`
|
||||
|
||||
Stream deliveries also include:
|
||||
- `x-stream-offset`
|
||||
- `x-stream-timestamp`
|
||||
- `x-work-acked` (based on the primary work group’s committed offset)
|
||||
- `x-work-committed-offset`
|
||||
- `x-work-group`
|
||||
|
||||
### Message Flow (Queue Publish)
|
||||
|
||||
```
|
||||
@@ -119,6 +135,46 @@ but DLQ moves are not currently triggered automatically.
|
||||
|
||||
---
|
||||
|
||||
## RabbitMQ Stream Compatibility
|
||||
|
||||
FluxMQ supports RabbitMQ-style stream queues at the protocol level:
|
||||
|
||||
- `queue.declare` with `x-queue-type=stream`
|
||||
- `basic.consume` with `x-stream-offset`
|
||||
- `x-max-age`, `x-max-length-bytes`, `x-max-length` for retention
|
||||
|
||||
Stream queues are append-only: acks do **not** delete messages. Messages are
|
||||
removed only when retention policies allow it, and only up to the safe
|
||||
truncation point for queue-mode consumers.
|
||||
|
||||
`x-stream-offset` supports:
|
||||
- `first`
|
||||
- `last`
|
||||
- `next`
|
||||
- `offset=<n>`
|
||||
- `timestamp=<unix-seconds|unix-millis>`
|
||||
|
||||
FluxMQ extensions for stream consumers:
|
||||
- `x-work-acked` and `x-work-committed-offset` to report delivery status for the
|
||||
configured primary work group.
|
||||
- `x-work-group` to identify the group used for status.
|
||||
- Optional `x-consumer-group` on `basic.consume` to persist a shared cursor.
|
||||
If omitted, the consumer tag is used as the stream group ID.
|
||||
|
||||
Primary work group is configured per queue (see configuration section) and is
|
||||
used only for delivery status reporting; it does not affect routing.
|
||||
|
||||
## Model Alignment
|
||||
|
||||
FluxMQ’s queue model aligns with:
|
||||
|
||||
- **Kafka**: append-only log + consumer groups + retention.
|
||||
- **Pulsar**: subscriptions + retention window for replay.
|
||||
- **NATS JetStream**: queue vs log semantics are defined by consumer mode.
|
||||
|
||||
Stream queues provide log-like semantics, while classic queue groups preserve
|
||||
work-queue behavior.
|
||||
|
||||
## Log Storage Engine
|
||||
|
||||
Queues are backed by an **append-only log** with segments and sparse indexes.
|
||||
@@ -151,10 +207,15 @@ Each queue has a single log (no partitions).
|
||||
|
||||
### Retention
|
||||
|
||||
The queue manager periodically truncates logs to the **minimum committed offset**
|
||||
across all consumer groups. Time-based or size-based retention policies exist in
|
||||
`logstorage` (segment manager retention) but are not wired into the runtime
|
||||
loop or exposed via the main config.
|
||||
The queue manager truncates logs to a **safe offset** that respects:
|
||||
|
||||
- The minimum committed offset across **queue-mode** consumer groups
|
||||
- The queue’s retention policy (time/size/message-count)
|
||||
|
||||
This means:
|
||||
- Queue-mode consumers never lose unacked data.
|
||||
- Stream consumers do not block truncation.
|
||||
- Retention keeps data available for event-log readers as configured.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -11,7 +11,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/fluxmq/broker"
|
||||
"github.com/absmach/fluxmq/config"
|
||||
"github.com/absmach/fluxmq/mqtt/broker"
|
||||
"github.com/absmach/fluxmq/server/tcp"
|
||||
"github.com/absmach/fluxmq/storage/badger"
|
||||
"github.com/absmach/fluxmq/testutil"
|
||||
@@ -40,7 +41,7 @@ func startBroker(t *testing.T, dataDir string, tcpPort int) *testBrokerInstance
|
||||
require.NoError(t, err)
|
||||
|
||||
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
|
||||
b := broker.NewBroker(store, nil, nullLogger, nil, nil, nil, nil)
|
||||
b := broker.NewBroker(store, nil, nullLogger, nil, nil, nil, nil, config.SessionConfig{})
|
||||
|
||||
tcpAddr := fmt.Sprintf("127.0.0.1:%d", tcpPort)
|
||||
tcpCfg := tcp.Config{
|
||||
|
||||
@@ -100,6 +100,16 @@ func (a *Adapter) Store() *Store {
|
||||
return a.store
|
||||
}
|
||||
|
||||
// OffsetByTime returns the offset for the given timestamp.
|
||||
func (a *Adapter) OffsetByTime(ctx context.Context, queueName string, ts time.Time) (uint64, error) {
|
||||
return a.store.LookupByTime(queueName, ts)
|
||||
}
|
||||
|
||||
// OffsetBySize returns the offset to keep when enforcing size retention.
|
||||
func (a *Adapter) OffsetBySize(ctx context.Context, queueName string, retentionBytes int64) (uint64, error) {
|
||||
return a.store.RetentionOffsetBySize(queueName, retentionBytes)
|
||||
}
|
||||
|
||||
// QueueStore interface implementation
|
||||
|
||||
// CreateQueue creates a new queue with the given configuration.
|
||||
|
||||
@@ -81,6 +81,9 @@ func (s *ConsumerGroupStateStore) loadAll() error {
|
||||
if state.Cursor == nil {
|
||||
state.Cursor = &types.QueueCursor{}
|
||||
}
|
||||
if state.Mode == "" {
|
||||
state.Mode = types.GroupModeQueue
|
||||
}
|
||||
if state.PEL == nil {
|
||||
state.PEL = make(map[string][]*types.PendingEntry)
|
||||
}
|
||||
@@ -143,6 +146,9 @@ func (s *ConsumerGroupStateStore) loadGroup(queueName, groupID string) (*types.C
|
||||
if state.Cursor == nil {
|
||||
state.Cursor = &types.QueueCursor{}
|
||||
}
|
||||
if state.Mode == "" {
|
||||
state.Mode = types.GroupModeQueue
|
||||
}
|
||||
if state.PEL == nil {
|
||||
state.PEL = make(map[string][]*types.PendingEntry)
|
||||
}
|
||||
|
||||
@@ -443,6 +443,40 @@ func (m *SegmentManager) ApplyRetention() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RetentionOffsetBySize returns the offset to keep when enforcing size retention.
|
||||
// It uses segment granularity (does not split segments).
|
||||
func (m *SegmentManager) RetentionOffsetBySize(retentionBytes int64) (uint64, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if retentionBytes <= 0 {
|
||||
return m.headOffset, nil
|
||||
}
|
||||
|
||||
var totalSize int64
|
||||
for _, seg := range m.segments {
|
||||
totalSize += seg.Size()
|
||||
}
|
||||
|
||||
if totalSize <= retentionBytes {
|
||||
return m.headOffset, nil
|
||||
}
|
||||
|
||||
sizeToTrim := totalSize - retentionBytes
|
||||
var trimmed int64
|
||||
|
||||
for _, seg := range m.segments {
|
||||
segSize := seg.Size()
|
||||
if trimmed+segSize < sizeToTrim {
|
||||
trimmed += segSize
|
||||
continue
|
||||
}
|
||||
return seg.BaseOffset(), nil
|
||||
}
|
||||
|
||||
return m.headOffset, nil
|
||||
}
|
||||
|
||||
// applyTimeRetention removes segments older than the cutoff time.
|
||||
func (m *SegmentManager) applyTimeRetention(cutoff time.Time) error {
|
||||
// Keep at least one segment (the active one)
|
||||
|
||||
@@ -244,6 +244,26 @@ func (s *Store) ReadRange(queueName string, startOffset, endOffset uint64, maxMe
|
||||
return manager.ReadRange(startOffset, endOffset, maxMessages)
|
||||
}
|
||||
|
||||
// LookupByTime finds the offset for the given timestamp.
|
||||
func (s *Store) LookupByTime(queueName string, ts time.Time) (uint64, error) {
|
||||
manager, err := s.getQueue(queueName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return manager.LookupByTime(ts)
|
||||
}
|
||||
|
||||
// RetentionOffsetBySize returns the offset to keep when enforcing size retention.
|
||||
func (s *Store) RetentionOffsetBySize(queueName string, retentionBytes int64) (uint64, error) {
|
||||
manager, err := s.getQueue(queueName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return manager.RetentionOffsetBySize(retentionBytes)
|
||||
}
|
||||
|
||||
// Head returns the head offset for a queue.
|
||||
func (s *Store) Head(queueName string) (uint64, error) {
|
||||
manager, err := s.getQueue(queueName)
|
||||
|
||||
+109
-1
@@ -21,6 +21,7 @@ var (
|
||||
ErrConsumerNotFound = errors.New("consumer not found")
|
||||
ErrMessageNotPending = errors.New("message not in pending list")
|
||||
ErrInvalidOffset = errors.New("invalid offset")
|
||||
ErrGroupModeMismatch = errors.New("consumer group mode mismatch")
|
||||
)
|
||||
|
||||
// Manager handles consumer group operations including claiming,
|
||||
@@ -68,10 +69,21 @@ func NewManager(queueStore storage.QueueStore, groupStore storage.ConsumerGroupS
|
||||
}
|
||||
|
||||
// GetOrCreateGroup retrieves or creates a consumer group.
|
||||
func (m *Manager) GetOrCreateGroup(ctx context.Context, queueName, groupID, pattern string) (*types.ConsumerGroupState, error) {
|
||||
func (m *Manager) GetOrCreateGroup(ctx context.Context, queueName, groupID, pattern string, mode types.ConsumerGroupMode) (*types.ConsumerGroupState, error) {
|
||||
// Try to get existing group
|
||||
group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID)
|
||||
if err == nil {
|
||||
if mode == "" {
|
||||
return group, nil
|
||||
}
|
||||
if group.Mode == "" {
|
||||
group.Mode = mode
|
||||
_ = m.groupStore.UpdateConsumerGroup(ctx, group)
|
||||
return group, nil
|
||||
}
|
||||
if group.Mode != mode {
|
||||
return nil, ErrGroupModeMismatch
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
@@ -82,6 +94,9 @@ func (m *Manager) GetOrCreateGroup(ctx context.Context, queueName, groupID, patt
|
||||
|
||||
// Create new group
|
||||
group = types.NewConsumerGroupState(queueName, groupID, pattern)
|
||||
if mode != "" {
|
||||
group.Mode = mode
|
||||
}
|
||||
|
||||
if err := m.groupStore.CreateConsumerGroup(ctx, group); err != nil {
|
||||
// Handle race condition - another process might have created it
|
||||
@@ -180,6 +195,70 @@ func (m *Manager) ClaimBatch(ctx context.Context, queueName, groupID, consumerID
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// ClaimBatchStream retrieves multiple messages for a stream consumer without PEL tracking.
|
||||
// It advances the cursor once per batch for efficiency.
|
||||
func (m *Manager) ClaimBatchStream(ctx context.Context, queueName, groupID, consumerID string, filter *Filter, limit int) ([]*types.Message, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if limit <= 0 {
|
||||
limit = m.config.ClaimBatchSize
|
||||
}
|
||||
|
||||
group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cursor := group.GetCursor()
|
||||
tail, err := m.queueStore.Tail(ctx, group.QueueName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var messages []*types.Message
|
||||
var newCursor uint64 = cursor.Cursor
|
||||
|
||||
for newCursor < tail && len(messages) < limit {
|
||||
offset := newCursor
|
||||
newCursor++
|
||||
|
||||
msg, err := m.queueStore.Read(ctx, group.QueueName, offset)
|
||||
if err != nil {
|
||||
if err == storage.ErrOffsetOutOfRange {
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if filter != nil {
|
||||
queueRoot := "$queue/" + group.QueueName
|
||||
routingKey := types.ExtractRoutingKey(msg.Topic, queueRoot)
|
||||
if !filter.Matches(routingKey) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
if len(messages) == 0 {
|
||||
return nil, ErrNoMessages
|
||||
}
|
||||
|
||||
if newCursor > cursor.Cursor {
|
||||
if err := m.groupStore.UpdateCursor(ctx, group.QueueName, group.ID, newCursor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Keep committed in sync for stream groups.
|
||||
if err := m.groupStore.UpdateCommitted(ctx, group.QueueName, group.ID, newCursor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// claimFromCursor tries to claim a message from the cursor position.
|
||||
func (m *Manager) claimFromCursor(ctx context.Context, group *types.ConsumerGroupState, consumerID string, filter *Filter) (*types.Message, error) {
|
||||
cursor := group.GetCursor()
|
||||
@@ -465,6 +544,35 @@ func (m *Manager) GetMinCommittedOffset(ctx context.Context, queueName string) (
|
||||
return minCommitted, nil
|
||||
}
|
||||
|
||||
// GetMinCommittedOffsetByMode returns the minimum committed offset for groups matching mode.
|
||||
// If no groups of that mode exist, returns the queue tail.
|
||||
func (m *Manager) GetMinCommittedOffsetByMode(ctx context.Context, queueName string, mode types.ConsumerGroupMode) (uint64, error) {
|
||||
groups, err := m.groupStore.ListConsumerGroups(ctx, queueName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var minCommitted uint64
|
||||
first := true
|
||||
|
||||
for _, group := range groups {
|
||||
if mode != "" && group.Mode != mode {
|
||||
continue
|
||||
}
|
||||
cursor := group.GetCursor()
|
||||
if first || cursor.Committed < minCommitted {
|
||||
minCommitted = cursor.Committed
|
||||
first = false
|
||||
}
|
||||
}
|
||||
|
||||
if first {
|
||||
return m.queueStore.Tail(ctx, queueName)
|
||||
}
|
||||
|
||||
return minCommitted, nil
|
||||
}
|
||||
|
||||
// UpdateHeartbeat updates the heartbeat timestamp for a consumer.
|
||||
func (m *Manager) UpdateHeartbeat(ctx context.Context, queueName, groupID, consumerID string) error {
|
||||
m.mu.Lock()
|
||||
|
||||
+267
-17
@@ -7,6 +7,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -239,6 +241,11 @@ func (m *Manager) CreateQueue(ctx context.Context, config types.QueueConfig) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateQueue updates an existing queue.
|
||||
func (m *Manager) UpdateQueue(ctx context.Context, config types.QueueConfig) error {
|
||||
return m.queueStore.UpdateQueue(ctx, config)
|
||||
}
|
||||
|
||||
// GetOrCreateQueue gets or creates a queue with default configuration.
|
||||
func (m *Manager) GetOrCreateQueue(ctx context.Context, queueName string, topics ...string) (*types.QueueConfig, error) {
|
||||
// Try to get existing
|
||||
@@ -421,19 +428,34 @@ func (m *Manager) Enqueue(ctx context.Context, topic string, payload []byte, pro
|
||||
|
||||
// SubscribeWithCursor adds a consumer with explicit cursor positioning.
|
||||
func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern string, clientID, groupID, proxyNodeID string, cursor *types.CursorOption) error {
|
||||
mode := types.GroupModeQueue
|
||||
if cursor != nil && cursor.Mode != "" {
|
||||
mode = cursor.Mode
|
||||
}
|
||||
if cursor == nil || cursor.Position == types.CursorDefault {
|
||||
return m.Subscribe(ctx, queueName, pattern, clientID, groupID, proxyNodeID)
|
||||
if mode != types.GroupModeStream {
|
||||
return m.Subscribe(ctx, queueName, pattern, clientID, groupID, proxyNodeID)
|
||||
}
|
||||
cursor = &types.CursorOption{Position: types.CursorLatest, Mode: mode}
|
||||
}
|
||||
|
||||
// Ensure queue exists
|
||||
queueTopicPattern := "$queue/" + queueName + "/#"
|
||||
_, err := m.GetOrCreateQueue(ctx, queueName, queueTopicPattern)
|
||||
queueCfg, err := m.GetOrCreateQueue(ctx, queueName, queueTopicPattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get or create queue: %w", err)
|
||||
}
|
||||
if mode == types.GroupModeStream && queueCfg != nil && queueCfg.Type != types.QueueTypeStream {
|
||||
queueCfg.Type = types.QueueTypeStream
|
||||
_ = m.queueStore.UpdateQueue(ctx, *queueCfg)
|
||||
}
|
||||
|
||||
if groupID == "" {
|
||||
groupID = extractGroupFromClientID(clientID)
|
||||
if mode == types.GroupModeStream {
|
||||
groupID = clientID
|
||||
} else {
|
||||
groupID = extractGroupFromClientID(clientID)
|
||||
}
|
||||
}
|
||||
|
||||
patternGroupID := groupID
|
||||
@@ -441,7 +463,7 @@ func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern st
|
||||
patternGroupID = fmt.Sprintf("%s@%s", groupID, pattern)
|
||||
}
|
||||
|
||||
group, err := m.consumerManager.GetOrCreateGroup(ctx, queueName, patternGroupID, pattern)
|
||||
group, err := m.consumerManager.GetOrCreateGroup(ctx, queueName, patternGroupID, pattern, mode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -469,6 +491,12 @@ func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern st
|
||||
offset = tail
|
||||
}
|
||||
m.groupStore.UpdateCursor(ctx, queueName, group.ID, offset)
|
||||
case types.CursorTimestamp:
|
||||
if !cursor.Timestamp.IsZero() {
|
||||
if offset, err := m.offsetByTime(ctx, queueName, cursor.Timestamp); err == nil {
|
||||
m.groupStore.UpdateCursor(ctx, queueName, group.ID, offset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.consumerManager.RegisterConsumer(ctx, queueName, group.ID, clientID, clientID, proxyNodeID); err != nil {
|
||||
@@ -485,6 +513,7 @@ func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern st
|
||||
ConsumerID: clientID,
|
||||
ClientID: clientID,
|
||||
Pattern: pattern,
|
||||
Mode: string(mode),
|
||||
ProxyNodeID: proxyNodeID,
|
||||
RegisteredAt: time.Now(),
|
||||
}
|
||||
@@ -501,7 +530,8 @@ func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern st
|
||||
slog.String("queue", queueName),
|
||||
slog.String("group", patternGroupID),
|
||||
slog.String("client", clientID),
|
||||
slog.String("cursor", fmt.Sprintf("%d", cursor.Position)))
|
||||
slog.String("cursor", fmt.Sprintf("%d", cursor.Position)),
|
||||
slog.String("mode", string(mode)))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -528,7 +558,7 @@ func (m *Manager) Subscribe(ctx context.Context, queueName, pattern string, clie
|
||||
}
|
||||
|
||||
// Get or create consumer group
|
||||
group, err := m.consumerManager.GetOrCreateGroup(ctx, queueName, patternGroupID, pattern)
|
||||
group, err := m.consumerManager.GetOrCreateGroup(ctx, queueName, patternGroupID, pattern, types.GroupModeQueue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -549,6 +579,7 @@ func (m *Manager) Subscribe(ctx context.Context, queueName, pattern string, clie
|
||||
ConsumerID: clientID,
|
||||
ClientID: clientID,
|
||||
Pattern: pattern,
|
||||
Mode: string(types.GroupModeQueue),
|
||||
ProxyNodeID: proxyNodeID,
|
||||
RegisteredAt: time.Now(),
|
||||
}
|
||||
@@ -622,6 +653,20 @@ func (m *Manager) Ack(ctx context.Context, queueName, messageID, groupID string)
|
||||
return err
|
||||
}
|
||||
|
||||
if groupID != "" {
|
||||
if group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID); err == nil {
|
||||
if group.Mode == types.GroupModeStream {
|
||||
cursor := group.GetCursor()
|
||||
next := offset + 1
|
||||
if next > cursor.Cursor {
|
||||
_ = m.groupStore.UpdateCursor(ctx, queueName, group.ID, next)
|
||||
_ = m.groupStore.UpdateCommitted(ctx, queueName, group.ID, next)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find the consumer that has this message pending
|
||||
groups, err := m.groupStore.ListConsumerGroups(ctx, queueName)
|
||||
if err != nil {
|
||||
@@ -633,6 +678,15 @@ func (m *Manager) Ack(ctx context.Context, queueName, messageID, groupID string)
|
||||
if groupID != "" && group.ID != groupID {
|
||||
continue
|
||||
}
|
||||
if group.Mode == types.GroupModeStream {
|
||||
cursor := group.GetCursor()
|
||||
next := offset + 1
|
||||
if next > cursor.Cursor {
|
||||
_ = m.groupStore.UpdateCursor(ctx, queueName, group.ID, next)
|
||||
_ = m.groupStore.UpdateCommitted(ctx, queueName, group.ID, next)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find and ack the message
|
||||
for consumerID := range group.PEL {
|
||||
@@ -654,6 +708,14 @@ func (m *Manager) Nack(ctx context.Context, queueName, messageID, groupID string
|
||||
return err
|
||||
}
|
||||
|
||||
if groupID != "" {
|
||||
if group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID); err == nil {
|
||||
if group.Mode == types.GroupModeStream {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groups, err := m.groupStore.ListConsumerGroups(ctx, queueName)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -663,6 +725,9 @@ func (m *Manager) Nack(ctx context.Context, queueName, messageID, groupID string
|
||||
if groupID != "" && group.ID != groupID {
|
||||
continue
|
||||
}
|
||||
if group.Mode == types.GroupModeStream {
|
||||
return nil
|
||||
}
|
||||
|
||||
for consumerID := range group.PEL {
|
||||
err := m.consumerManager.Nack(ctx, queueName, group.ID, consumerID, offset)
|
||||
@@ -683,6 +748,14 @@ func (m *Manager) Reject(ctx context.Context, queueName, messageID, groupID, rea
|
||||
return err
|
||||
}
|
||||
|
||||
if groupID != "" {
|
||||
if group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID); err == nil {
|
||||
if group.Mode == types.GroupModeStream {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
groups, err := m.groupStore.ListConsumerGroups(ctx, queueName)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -692,6 +765,9 @@ func (m *Manager) Reject(ctx context.Context, queueName, messageID, groupID, rea
|
||||
if groupID != "" && group.ID != groupID {
|
||||
continue
|
||||
}
|
||||
if group.Mode == types.GroupModeStream {
|
||||
return nil
|
||||
}
|
||||
|
||||
for consumerID := range group.PEL {
|
||||
err := m.consumerManager.Reject(ctx, queueName, group.ID, consumerID, offset, reason)
|
||||
@@ -761,6 +837,31 @@ func (m *Manager) deliverMessages() {
|
||||
}
|
||||
|
||||
for _, queueConfig := range queues {
|
||||
primaryGroup := strings.TrimSpace(queueConfig.PrimaryGroup)
|
||||
primaryCommitted := make(map[string]uint64)
|
||||
getPrimaryCommitted := func(pattern string) (uint64, bool) {
|
||||
if primaryGroup == "" {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
patternGroupID := primaryGroup
|
||||
if pattern != "" {
|
||||
patternGroupID = fmt.Sprintf("%s@%s", primaryGroup, pattern)
|
||||
}
|
||||
|
||||
if val, ok := primaryCommitted[patternGroupID]; ok {
|
||||
return val, true
|
||||
}
|
||||
|
||||
committed, err := m.consumerManager.GetCommittedOffset(ctx, queueConfig.Name, patternGroupID)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
primaryCommitted[patternGroupID] = committed
|
||||
return committed, true
|
||||
}
|
||||
|
||||
// Deliver to local consumer groups
|
||||
groups, err := m.groupStore.ListConsumerGroups(ctx, queueConfig.Name)
|
||||
if err != nil {
|
||||
@@ -768,7 +869,7 @@ func (m *Manager) deliverMessages() {
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
m.deliverToGroup(ctx, &queueConfig, group)
|
||||
m.deliverToGroup(ctx, &queueConfig, group, getPrimaryCommitted)
|
||||
}
|
||||
|
||||
// Also deliver to remote consumers registered in cluster
|
||||
@@ -801,8 +902,12 @@ func (m *Manager) deliverToRemoteConsumers(ctx context.Context, config *types.Qu
|
||||
}
|
||||
|
||||
for groupID, groupConsumers := range consumersByGroup {
|
||||
mode := types.GroupModeQueue
|
||||
if groupConsumers[0].Mode != "" {
|
||||
mode = types.ConsumerGroupMode(groupConsumers[0].Mode)
|
||||
}
|
||||
// Get or create a local consumer group state for tracking cursor
|
||||
group, err := m.consumerManager.GetOrCreateGroup(ctx, config.Name, groupID, groupConsumers[0].Pattern)
|
||||
group, err := m.consumerManager.GetOrCreateGroup(ctx, config.Name, groupID, groupConsumers[0].Pattern, mode)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -814,23 +919,62 @@ func (m *Manager) deliverToRemoteConsumers(ctx context.Context, config *types.Qu
|
||||
}
|
||||
|
||||
// Round-robin across remote consumers in this group
|
||||
var workCommitted uint64
|
||||
var hasWorkCommitted bool
|
||||
if group.Mode == types.GroupModeStream && config.PrimaryGroup != "" {
|
||||
patternGroupID := config.PrimaryGroup
|
||||
if group.Pattern != "" {
|
||||
patternGroupID = fmt.Sprintf("%s@%s", config.PrimaryGroup, group.Pattern)
|
||||
}
|
||||
if committed, err := m.consumerManager.GetCommittedOffset(ctx, config.Name, patternGroupID); err == nil {
|
||||
workCommitted = committed
|
||||
hasWorkCommitted = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, consumerInfo := range groupConsumers {
|
||||
// Claim messages for this remote consumer
|
||||
msgs, err := m.consumerManager.ClaimBatch(ctx, config.Name, groupID, consumerInfo.ConsumerID, filter, m.config.DeliveryBatchSize)
|
||||
var msgs []*types.Message
|
||||
var err error
|
||||
if group.Mode == types.GroupModeStream {
|
||||
msgs, err = m.consumerManager.ClaimBatchStream(ctx, config.Name, groupID, consumerInfo.ConsumerID, filter, m.config.DeliveryBatchSize)
|
||||
} else {
|
||||
msgs, err = m.consumerManager.ClaimBatch(ctx, config.Name, groupID, consumerInfo.ConsumerID, filter, m.config.DeliveryBatchSize)
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Route each message to the remote node
|
||||
for _, msg := range msgs {
|
||||
payload := msg.GetPayload()
|
||||
properties := msg.Properties
|
||||
if group.Mode == types.GroupModeStream {
|
||||
propsCopy := make(map[string]string, len(msg.Properties)+4)
|
||||
for k, v := range msg.Properties {
|
||||
propsCopy[k] = v
|
||||
}
|
||||
// Decorate properties for stream consumers.
|
||||
propsCopy["x-stream-offset"] = fmt.Sprintf("%d", msg.Sequence)
|
||||
if !msg.CreatedAt.IsZero() {
|
||||
propsCopy["x-stream-timestamp"] = fmt.Sprintf("%d", msg.CreatedAt.UnixMilli())
|
||||
}
|
||||
if hasWorkCommitted {
|
||||
propsCopy["x-work-committed-offset"] = fmt.Sprintf("%d", workCommitted)
|
||||
propsCopy["x-work-acked"] = strconv.FormatBool(msg.Sequence < workCommitted)
|
||||
propsCopy["x-work-group"] = config.PrimaryGroup
|
||||
}
|
||||
properties = propsCopy
|
||||
}
|
||||
|
||||
err := m.cluster.RouteQueueMessage(
|
||||
ctx,
|
||||
consumerInfo.ProxyNodeID,
|
||||
consumerInfo.ClientID,
|
||||
config.Name,
|
||||
msg.ID,
|
||||
msg.GetPayload(),
|
||||
msg.Properties,
|
||||
payload,
|
||||
properties,
|
||||
int64(msg.Sequence),
|
||||
)
|
||||
if err != nil {
|
||||
@@ -851,7 +995,7 @@ func (m *Manager) deliverToRemoteConsumers(ctx context.Context, config *types.Qu
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig, group *types.ConsumerGroupState) {
|
||||
func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig, group *types.ConsumerGroupState, primaryCommitted func(pattern string) (uint64, bool)) {
|
||||
if group.ConsumerCount() == 0 {
|
||||
return
|
||||
}
|
||||
@@ -869,7 +1013,13 @@ func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig,
|
||||
}
|
||||
|
||||
for _, consumerID := range consumers {
|
||||
msgs, err := m.consumerManager.ClaimBatch(ctx, config.Name, group.ID, consumerID, filter, m.config.DeliveryBatchSize)
|
||||
var msgs []*types.Message
|
||||
var err error
|
||||
if group.Mode == types.GroupModeStream {
|
||||
msgs, err = m.consumerManager.ClaimBatchStream(ctx, config.Name, group.ID, consumerID, filter, m.config.DeliveryBatchSize)
|
||||
} else {
|
||||
msgs, err = m.consumerManager.ClaimBatch(ctx, config.Name, group.ID, consumerID, filter, m.config.DeliveryBatchSize)
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -890,6 +1040,12 @@ func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig,
|
||||
m.consumerManager.UpdateHeartbeat(ctx, config.Name, group.ID, consumerID)
|
||||
}
|
||||
|
||||
var workCommitted uint64
|
||||
var hasWorkCommitted bool
|
||||
if group.Mode == types.GroupModeStream && primaryCommitted != nil {
|
||||
workCommitted, hasWorkCommitted = primaryCommitted(group.Pattern)
|
||||
}
|
||||
|
||||
for _, msg := range msgs {
|
||||
// Check if consumer is on a remote node
|
||||
if m.cluster != nil && consumerInfo.ProxyNodeID != "" && consumerInfo.ProxyNodeID != m.localNodeID {
|
||||
@@ -914,6 +1070,9 @@ func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig,
|
||||
} else if m.deliverFn != nil {
|
||||
// Local delivery
|
||||
deliveryMsg := m.createDeliveryMessage(msg, group.ID, config.Name)
|
||||
if group.Mode == types.GroupModeStream {
|
||||
m.decorateStreamDelivery(deliveryMsg, msg, group, workCommitted, hasWorkCommitted, config.PrimaryGroup)
|
||||
}
|
||||
|
||||
if err := m.deliverFn(ctx, consumerInfo.ClientID, deliveryMsg); err != nil {
|
||||
m.logger.Warn("queue message delivery failed",
|
||||
@@ -1009,14 +1168,21 @@ func (m *Manager) processRetention() {
|
||||
}
|
||||
|
||||
for _, queueConfig := range queues {
|
||||
// Get minimum committed offset across all groups
|
||||
minCommitted, err := m.consumerManager.GetMinCommittedOffset(ctx, queueConfig.Name)
|
||||
// Get minimum committed offset across queue-mode groups
|
||||
minCommitted, err := m.consumerManager.GetMinCommittedOffsetByMode(ctx, queueConfig.Name, types.GroupModeQueue)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Truncate log up to committed offset
|
||||
if err := m.queueStore.Truncate(ctx, queueConfig.Name, minCommitted); err != nil {
|
||||
truncateOffset := minCommitted
|
||||
if retentionOffset, hasRetention := m.computeRetentionOffset(ctx, &queueConfig); hasRetention {
|
||||
if retentionOffset < truncateOffset {
|
||||
truncateOffset = retentionOffset
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate log up to the safe offset
|
||||
if err := m.queueStore.Truncate(ctx, queueConfig.Name, truncateOffset); err != nil {
|
||||
m.logger.Debug("truncation error",
|
||||
slog.String("error", err.Error()),
|
||||
slog.String("queue", queueConfig.Name))
|
||||
@@ -1193,6 +1359,28 @@ func (m *Manager) createDeliveryMessage(msg *types.Message, groupID string, queu
|
||||
return deliveryMsg
|
||||
}
|
||||
|
||||
func (m *Manager) decorateStreamDelivery(delivery *brokerstorage.Message, msg *types.Message, group *types.ConsumerGroupState, workCommitted uint64, hasWorkCommitted bool, primaryGroup string) {
|
||||
if delivery == nil || msg == nil {
|
||||
return
|
||||
}
|
||||
if delivery.Properties == nil {
|
||||
delivery.Properties = make(map[string]string)
|
||||
}
|
||||
|
||||
delivery.Properties["x-stream-offset"] = fmt.Sprintf("%d", msg.Sequence)
|
||||
if !msg.CreatedAt.IsZero() {
|
||||
delivery.Properties["x-stream-timestamp"] = fmt.Sprintf("%d", msg.CreatedAt.UnixMilli())
|
||||
}
|
||||
|
||||
if hasWorkCommitted {
|
||||
delivery.Properties["x-work-committed-offset"] = fmt.Sprintf("%d", workCommitted)
|
||||
delivery.Properties["x-work-acked"] = strconv.FormatBool(msg.Sequence < workCommitted)
|
||||
if primaryGroup != "" {
|
||||
delivery.Properties["x-work-group"] = primaryGroup
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DeliveryMessage is the internal message format for queue delivery tracking.
|
||||
type DeliveryMessage struct {
|
||||
ID string
|
||||
@@ -1235,6 +1423,68 @@ func parseMessageID(messageID string) (uint64, error) {
|
||||
return offset, err
|
||||
}
|
||||
|
||||
func (m *Manager) offsetByTime(ctx context.Context, queueName string, ts time.Time) (uint64, error) {
|
||||
if provider, ok := m.queueStore.(storage.TimeOffsetProvider); ok {
|
||||
return provider.OffsetByTime(ctx, queueName, ts)
|
||||
}
|
||||
return m.queueStore.Head(ctx, queueName)
|
||||
}
|
||||
|
||||
func (m *Manager) offsetBySize(ctx context.Context, queueName string, retentionBytes int64) (uint64, error) {
|
||||
if provider, ok := m.queueStore.(storage.SizeOffsetProvider); ok {
|
||||
return provider.OffsetBySize(ctx, queueName, retentionBytes)
|
||||
}
|
||||
return m.queueStore.Head(ctx, queueName)
|
||||
}
|
||||
|
||||
func (m *Manager) computeRetentionOffset(ctx context.Context, config *types.QueueConfig) (uint64, bool) {
|
||||
if config == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var offset uint64
|
||||
hasRetention := false
|
||||
|
||||
if config.Retention.RetentionTime > 0 {
|
||||
cutoff := time.Now().Add(-config.Retention.RetentionTime)
|
||||
if off, err := m.offsetByTime(ctx, config.Name, cutoff); err == nil {
|
||||
if off > offset {
|
||||
offset = off
|
||||
}
|
||||
hasRetention = true
|
||||
}
|
||||
}
|
||||
|
||||
if config.Retention.RetentionBytes > 0 {
|
||||
if off, err := m.offsetBySize(ctx, config.Name, config.Retention.RetentionBytes); err == nil {
|
||||
if off > offset {
|
||||
offset = off
|
||||
}
|
||||
hasRetention = true
|
||||
}
|
||||
}
|
||||
|
||||
if config.Retention.RetentionMessages > 0 {
|
||||
head, err := m.queueStore.Head(ctx, config.Name)
|
||||
if err == nil {
|
||||
tail, err := m.queueStore.Tail(ctx, config.Name)
|
||||
if err == nil {
|
||||
if tail > head+uint64(config.Retention.RetentionMessages) {
|
||||
msgOffset := tail - uint64(config.Retention.RetentionMessages)
|
||||
if msgOffset > offset {
|
||||
offset = msgOffset
|
||||
}
|
||||
} else if head > offset {
|
||||
offset = head
|
||||
}
|
||||
hasRetention = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return offset, hasRetention
|
||||
}
|
||||
|
||||
func parseSubscriptionFilter(filter string) (queueName, pattern string) {
|
||||
for i, c := range filter {
|
||||
if c == '/' {
|
||||
|
||||
@@ -5,6 +5,7 @@ package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -329,6 +330,119 @@ func TestWildcardQueueSubscription(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamGroupDeliversWithoutPEL(t *testing.T) {
|
||||
logStore := memlog.New()
|
||||
groupStore := newMockGroupStore()
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
delivered := make(chan *brokerstorage.Message, 1)
|
||||
|
||||
deliverFn := func(ctx context.Context, clientID string, msg any) error {
|
||||
if m, ok := msg.(*brokerstorage.Message); ok {
|
||||
delivered <- m
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.DeliveryBatchSize = 10
|
||||
mgr := NewManager(logStore, groupStore, deliverFn, cfg, logger, nil)
|
||||
|
||||
queueCfg := types.DefaultQueueConfig("events", "$queue/events/#")
|
||||
queueCfg.Type = types.QueueTypeStream
|
||||
if err := mgr.CreateQueue(context.Background(), queueCfg); err != nil {
|
||||
t.Fatalf("CreateQueue failed: %v", err)
|
||||
}
|
||||
|
||||
cursor := &types.CursorOption{Position: types.CursorEarliest, Mode: types.GroupModeStream}
|
||||
if err := mgr.SubscribeWithCursor(context.Background(), "events", "", "client-1", "streamer", "", cursor); err != nil {
|
||||
t.Fatalf("SubscribeWithCursor failed: %v", err)
|
||||
}
|
||||
|
||||
if err := mgr.Publish(context.Background(), "$queue/events/test", []byte("hello"), nil); err != nil {
|
||||
t.Fatalf("Publish failed: %v", err)
|
||||
}
|
||||
|
||||
mgr.deliverMessages()
|
||||
|
||||
select {
|
||||
case msg := <-delivered:
|
||||
if got := msg.Properties["x-stream-offset"]; got != "0" {
|
||||
t.Fatalf("expected stream offset 0, got %q", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for delivery")
|
||||
}
|
||||
|
||||
group, err := groupStore.GetConsumerGroup(context.Background(), "events", "streamer")
|
||||
if err != nil {
|
||||
t.Fatalf("GetConsumerGroup failed: %v", err)
|
||||
}
|
||||
if count := group.PendingCount(); count != 0 {
|
||||
t.Fatalf("expected no pending entries, got %d", count)
|
||||
}
|
||||
if cursor := group.GetCursor().Cursor; cursor != 1 {
|
||||
t.Fatalf("expected cursor 1, got %d", cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamAckAdvancesCursor(t *testing.T) {
|
||||
logStore := memlog.New()
|
||||
groupStore := newMockGroupStore()
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
mgr := NewManager(logStore, groupStore, nil, DefaultConfig(), logger, nil)
|
||||
|
||||
queueCfg := types.DefaultQueueConfig("events", "$queue/events/#")
|
||||
queueCfg.Type = types.QueueTypeStream
|
||||
if err := mgr.CreateQueue(context.Background(), queueCfg); err != nil {
|
||||
t.Fatalf("CreateQueue failed: %v", err)
|
||||
}
|
||||
|
||||
cursor := &types.CursorOption{Position: types.CursorEarliest, Mode: types.GroupModeStream}
|
||||
if err := mgr.SubscribeWithCursor(context.Background(), "events", "", "client-1", "streamer", "", cursor); err != nil {
|
||||
t.Fatalf("SubscribeWithCursor failed: %v", err)
|
||||
}
|
||||
|
||||
if err := mgr.Ack(context.Background(), "events", "events:0", "streamer"); err != nil {
|
||||
t.Fatalf("Ack failed: %v", err)
|
||||
}
|
||||
|
||||
group, err := groupStore.GetConsumerGroup(context.Background(), "events", "streamer")
|
||||
if err != nil {
|
||||
t.Fatalf("GetConsumerGroup failed: %v", err)
|
||||
}
|
||||
if cursor := group.GetCursor().Cursor; cursor != 1 {
|
||||
t.Fatalf("expected cursor 1, got %d", cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetentionOffsetMessages(t *testing.T) {
|
||||
logStore := memlog.New()
|
||||
groupStore := newMockGroupStore()
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
mgr := NewManager(logStore, groupStore, nil, DefaultConfig(), logger, nil)
|
||||
|
||||
queueCfg := types.DefaultQueueConfig("events", "$queue/events/#")
|
||||
if err := mgr.CreateQueue(context.Background(), queueCfg); err != nil {
|
||||
t.Fatalf("CreateQueue failed: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := mgr.Publish(context.Background(), "$queue/events/test", []byte("msg"), nil); err != nil {
|
||||
t.Fatalf("Publish failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
queueCfg.Retention.RetentionMessages = 1
|
||||
queueCfg.Name = "events"
|
||||
offset, ok := mgr.computeRetentionOffset(context.Background(), &queueCfg)
|
||||
if !ok {
|
||||
t.Fatal("expected retention offset")
|
||||
}
|
||||
if offset != 2 {
|
||||
t.Fatalf("expected retention offset 2, got %d", offset)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExactQueueSubscription(t *testing.T) {
|
||||
logStore := memlog.New()
|
||||
groupStore := newMockGroupStore()
|
||||
|
||||
@@ -279,6 +279,54 @@ func (s *Store) Tail(ctx context.Context, queueName string) (uint64, error) {
|
||||
return sl.tail, nil
|
||||
}
|
||||
|
||||
// OffsetByTime returns the offset for the given timestamp.
|
||||
func (s *Store) OffsetByTime(ctx context.Context, queueName string, ts time.Time) (uint64, error) {
|
||||
sl, err := s.getQueueLog(queueName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
if len(sl.messages) == 0 {
|
||||
return sl.head, nil
|
||||
}
|
||||
|
||||
for i, msg := range sl.messages {
|
||||
if !msg.CreatedAt.IsZero() && !msg.CreatedAt.Before(ts) {
|
||||
return sl.head + uint64(i), nil
|
||||
}
|
||||
}
|
||||
|
||||
return sl.tail, nil
|
||||
}
|
||||
|
||||
// OffsetBySize returns the offset to keep when enforcing size retention.
|
||||
func (s *Store) OffsetBySize(ctx context.Context, queueName string, retentionBytes int64) (uint64, error) {
|
||||
sl, err := s.getQueueLog(queueName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
if retentionBytes <= 0 || len(sl.messages) == 0 {
|
||||
return sl.head, nil
|
||||
}
|
||||
|
||||
var total int64
|
||||
for i := len(sl.messages) - 1; i >= 0; i-- {
|
||||
total += int64(len(sl.messages[i].GetPayload()))
|
||||
if total > retentionBytes {
|
||||
return sl.head + uint64(i+1), nil
|
||||
}
|
||||
}
|
||||
|
||||
return sl.head, nil
|
||||
}
|
||||
|
||||
// Truncate removes all messages with offset < minOffset.
|
||||
func (s *Store) Truncate(ctx context.Context, queueName string, minOffset uint64) error {
|
||||
sl, err := s.getQueueLog(queueName)
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TimeOffsetProvider exposes time-based offset lookups.
|
||||
type TimeOffsetProvider interface {
|
||||
OffsetByTime(ctx context.Context, queueName string, ts time.Time) (uint64, error)
|
||||
}
|
||||
|
||||
// SizeOffsetProvider exposes size-based retention offsets.
|
||||
type SizeOffsetProvider interface {
|
||||
OffsetBySize(ctx context.Context, queueName string, retentionBytes int64) (uint64, error)
|
||||
}
|
||||
+30
-1
@@ -15,16 +15,21 @@ var ErrInvalidConfig = errors.New("invalid queue configuration")
|
||||
type CursorPosition int
|
||||
|
||||
const (
|
||||
CursorDefault CursorPosition = iota // resume from stored position
|
||||
CursorDefault CursorPosition = iota // resume from stored position
|
||||
CursorEarliest // start from beginning
|
||||
CursorLatest // start from end
|
||||
CursorOffset // start from specific offset
|
||||
CursorTimestamp // start from a timestamp
|
||||
)
|
||||
|
||||
// CursorOption specifies cursor positioning for SubscribeWithCursor.
|
||||
type CursorOption struct {
|
||||
Position CursorPosition
|
||||
Offset uint64 // only used when Position == CursorOffset
|
||||
// Timestamp is used when Position == CursorTimestamp.
|
||||
Timestamp time.Time
|
||||
// Mode defines the consumer group mode (queue or stream).
|
||||
Mode ConsumerGroupMode
|
||||
}
|
||||
|
||||
// Queue constants
|
||||
@@ -46,6 +51,11 @@ type QueueConfig struct {
|
||||
Name string
|
||||
Topics []string // Topic patterns that route to this queue (e.g., "sensors/#", "orders/+/created")
|
||||
Reserved bool // True for system queues like "mqtt" that cannot be deleted
|
||||
Type QueueType
|
||||
|
||||
// PrimaryGroup defines the consumer group whose committed offset is used
|
||||
// to report delivery status to stream consumers.
|
||||
PrimaryGroup string
|
||||
|
||||
// Durability
|
||||
Durable bool // true = persists indefinitely, false = ephemeral (cleaned up when no consumers remain)
|
||||
@@ -68,6 +78,14 @@ type QueueConfig struct {
|
||||
HeartbeatTimeout time.Duration
|
||||
}
|
||||
|
||||
// QueueType defines the queue behavior mode.
|
||||
type QueueType string
|
||||
|
||||
const (
|
||||
QueueTypeClassic QueueType = "classic"
|
||||
QueueTypeStream QueueType = "stream"
|
||||
)
|
||||
|
||||
// RetryPolicy defines retry behavior for failed messages.
|
||||
type RetryPolicy struct {
|
||||
MaxRetries int
|
||||
@@ -134,6 +152,7 @@ func DefaultQueueConfig(name string, topics ...string) QueueConfig {
|
||||
Name: name,
|
||||
Topics: topics,
|
||||
Reserved: false,
|
||||
Type: QueueTypeClassic,
|
||||
Durable: true,
|
||||
MaxMessageSize: 10 * 1024 * 1024, // 10MB
|
||||
MaxDepth: 100000,
|
||||
@@ -178,6 +197,8 @@ type QueueConfigInput struct {
|
||||
Name string
|
||||
Topics []string
|
||||
Reserved bool
|
||||
Type QueueType
|
||||
PrimaryGroup string
|
||||
MaxMessageSize int64
|
||||
MaxDepth int64
|
||||
MessageTTL time.Duration
|
||||
@@ -187,6 +208,7 @@ type QueueConfigInput struct {
|
||||
Multiplier float64
|
||||
DLQEnabled bool
|
||||
DLQTopic string
|
||||
Retention RetentionPolicy
|
||||
}
|
||||
|
||||
// FromInput creates a QueueConfig from a simplified input config.
|
||||
@@ -194,6 +216,11 @@ func FromInput(input QueueConfigInput) QueueConfig {
|
||||
cfg := DefaultQueueConfig(input.Name, input.Topics...)
|
||||
cfg.Durable = true
|
||||
cfg.Reserved = input.Reserved
|
||||
if input.Type != "" {
|
||||
cfg.Type = input.Type
|
||||
}
|
||||
cfg.PrimaryGroup = input.PrimaryGroup
|
||||
cfg.Retention = input.Retention
|
||||
|
||||
if input.MaxMessageSize > 0 {
|
||||
cfg.MaxMessageSize = input.MaxMessageSize
|
||||
@@ -254,6 +281,8 @@ func (c *QueueConfig) Validate() error {
|
||||
return ErrInvalidConfig
|
||||
case c.Reserved && !c.Durable:
|
||||
return ErrInvalidConfig
|
||||
case c.Type != "" && c.Type != QueueTypeClassic && c.Type != QueueTypeStream:
|
||||
return ErrInvalidConfig
|
||||
}
|
||||
|
||||
// Validate replication config if enabled
|
||||
|
||||
@@ -43,6 +43,7 @@ type ConsumerGroupState struct {
|
||||
ID string // Group identifier
|
||||
QueueName string // Queue this group consumes from
|
||||
Pattern string // Subscription pattern (e.g., "sensors/#")
|
||||
Mode ConsumerGroupMode
|
||||
|
||||
// Queue cursor state (single cursor per queue, no partitions)
|
||||
Cursor *QueueCursor
|
||||
@@ -59,6 +60,14 @@ type ConsumerGroupState struct {
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// ConsumerGroupMode defines how a consumer group is tracked.
|
||||
type ConsumerGroupMode string
|
||||
|
||||
const (
|
||||
GroupModeQueue ConsumerGroupMode = "queue"
|
||||
GroupModeStream ConsumerGroupMode = "stream"
|
||||
)
|
||||
|
||||
// NewConsumerGroupState creates a new consumer group state.
|
||||
func NewConsumerGroupState(queueName, groupID, pattern string) *ConsumerGroupState {
|
||||
now := time.Now()
|
||||
@@ -66,6 +75,7 @@ func NewConsumerGroupState(queueName, groupID, pattern string) *ConsumerGroupSta
|
||||
ID: groupID,
|
||||
QueueName: queueName,
|
||||
Pattern: pattern,
|
||||
Mode: GroupModeQueue,
|
||||
Cursor: &QueueCursor{
|
||||
Cursor: 0,
|
||||
Committed: 0,
|
||||
|
||||
@@ -145,6 +145,12 @@ func (s *Server) handlePublish(w mux.ResponseWriter, r *mux.Message) {
|
||||
return
|
||||
}
|
||||
|
||||
// if !strings.HasPrefix(path, "/mqtt/publish/") {
|
||||
// s.logger.Warn("coap_publish_invalid_path", slog.String("path", path))
|
||||
// s.sendResponse(w, r, codes.BadRequest, "invalid path")
|
||||
// return
|
||||
// }
|
||||
|
||||
topic := strings.TrimPrefix(path, "/mqtt/publish/")
|
||||
if topic == "" {
|
||||
s.logger.Warn("coap_publish_missing_topic")
|
||||
|
||||
+173
-50
@@ -4,22 +4,124 @@
|
||||
package coap
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/fluxmq/broker"
|
||||
"github.com/absmach/fluxmq/cluster"
|
||||
"github.com/absmach/fluxmq/config"
|
||||
"github.com/absmach/fluxmq/mqtt/broker"
|
||||
"github.com/absmach/fluxmq/storage/memory"
|
||||
piondtls "github.com/pion/dtls/v3"
|
||||
"github.com/plgd-dev/go-coap/v3/message"
|
||||
"github.com/plgd-dev/go-coap/v3/message/codes"
|
||||
"github.com/plgd-dev/go-coap/v3/message/pool"
|
||||
"github.com/plgd-dev/go-coap/v3/mux"
|
||||
)
|
||||
|
||||
type stubConn struct {
|
||||
last *pool.Message
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newStubConn() *stubConn {
|
||||
return &stubConn{done: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (c *stubConn) AcquireMessage(ctx context.Context) *pool.Message {
|
||||
return pool.NewMessage(ctx)
|
||||
}
|
||||
|
||||
func (c *stubConn) ReleaseMessage(*pool.Message) {}
|
||||
|
||||
func (c *stubConn) Ping(ctx context.Context) error { return nil }
|
||||
func (c *stubConn) Get(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) Delete(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) Post(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) Put(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) Observe(ctx context.Context, path string, observeFunc func(notification *pool.Message), opts ...message.Option) (mux.Observation, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) RemoteAddr() net.Addr { return stubAddr("coap-stub") }
|
||||
func (c *stubConn) NetConn() net.Conn { return nil }
|
||||
func (c *stubConn) Context() context.Context {
|
||||
return context.Background()
|
||||
}
|
||||
func (c *stubConn) SetContextValue(key interface{}, val interface{}) {}
|
||||
func (c *stubConn) WriteMessage(req *pool.Message) error {
|
||||
c.last = req
|
||||
return nil
|
||||
}
|
||||
func (c *stubConn) Do(req *pool.Message) (*pool.Message, error) { return nil, nil }
|
||||
func (c *stubConn) DoObserve(req *pool.Message, observeFunc func(req *pool.Message)) (mux.Observation, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) Close() error {
|
||||
close(c.done)
|
||||
return nil
|
||||
}
|
||||
func (c *stubConn) Sequence() uint64 { return 0 }
|
||||
func (c *stubConn) Done() <-chan struct{} {
|
||||
return c.done
|
||||
}
|
||||
func (c *stubConn) AddOnClose(func()) {}
|
||||
func (c *stubConn) NewGetRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) NewObserveRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) NewPutRequest(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) NewPostRequest(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *stubConn) NewDeleteRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type stubResponseWriter struct {
|
||||
conn *stubConn
|
||||
msg *pool.Message
|
||||
}
|
||||
|
||||
func (w *stubResponseWriter) SetResponse(code codes.Code, contentFormat message.MediaType, d io.ReadSeeker, opts ...message.Option) error {
|
||||
resp := w.conn.AcquireMessage(context.Background())
|
||||
resp.SetCode(code)
|
||||
resp.SetBody(d)
|
||||
w.conn.last = resp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *stubResponseWriter) Conn() mux.Conn { return w.conn }
|
||||
func (w *stubResponseWriter) SetMessage(m *pool.Message) {
|
||||
w.msg = m
|
||||
}
|
||||
func (w *stubResponseWriter) Message() *pool.Message { return w.msg }
|
||||
|
||||
type stubAddr string
|
||||
|
||||
func (a stubAddr) Network() string { return "stub" }
|
||||
func (a stubAddr) String() string { return string(a) }
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
store := memory.New()
|
||||
cl := cluster.NewNoopCluster("test-node")
|
||||
stats := broker.NewStats()
|
||||
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil)
|
||||
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
@@ -45,8 +147,7 @@ func TestNew(t *testing.T) {
|
||||
|
||||
func TestConfig_TLSConfig(t *testing.T) {
|
||||
cfg := Config{
|
||||
Address: ":5684",
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
Address: ":5684",
|
||||
TLSConfig: &piondtls.Config{
|
||||
ClientAuth: piondtls.RequireAndVerifyClientCert,
|
||||
},
|
||||
@@ -60,63 +161,85 @@ func TestConfig_TLSConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ListenUDP_ContextCancel(t *testing.T) {
|
||||
func TestHandleHealth(t *testing.T) {
|
||||
store := memory.New()
|
||||
cl := cluster.NewNoopCluster("test-node")
|
||||
stats := broker.NewStats()
|
||||
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil)
|
||||
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: ":15683", // Use non-standard port for testing
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
server := New(Config{}, b, slog.Default())
|
||||
conn := newStubConn()
|
||||
writer := &stubResponseWriter{conn: conn}
|
||||
|
||||
req := &mux.Message{Message: pool.NewMessage(context.Background())}
|
||||
server.handleHealth(writer, req)
|
||||
|
||||
if conn.last == nil {
|
||||
t.Fatal("expected response message")
|
||||
}
|
||||
if conn.last.Code() != codes.Content {
|
||||
t.Fatalf("expected code %v, got %v", codes.Content, conn.last.Code())
|
||||
}
|
||||
body, err := conn.last.ReadBody()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body: %v", err)
|
||||
}
|
||||
if string(body) != "healthy" {
|
||||
t.Fatalf("expected body %q, got %q", "healthy", string(body))
|
||||
}
|
||||
}
|
||||
|
||||
server := New(cfg, b, slog.Default())
|
||||
func TestHandlePublish(t *testing.T) {
|
||||
store := memory.New()
|
||||
cl := cluster.NewNoopCluster("test-node")
|
||||
stats := broker.NewStats()
|
||||
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
server := New(Config{}, b, slog.Default())
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- server.Listen(ctx)
|
||||
}()
|
||||
t.Run("missing topic", func(t *testing.T) {
|
||||
conn := newStubConn()
|
||||
writer := &stubResponseWriter{conn: conn}
|
||||
|
||||
// Give server time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
reqMsg := pool.NewMessage(context.Background())
|
||||
reqMsg.MustSetPath("/mqtt/publish")
|
||||
req := &mux.Message{Message: reqMsg}
|
||||
|
||||
// Cancel context to trigger shutdown
|
||||
cancel()
|
||||
server.handlePublish(writer, req)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil error on shutdown, got: %v", err)
|
||||
if conn.last == nil {
|
||||
t.Fatal("expected response message")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Server did not shut down in time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_ListenDTLS_InvalidConfig(t *testing.T) {
|
||||
store := memory.New()
|
||||
cl := cluster.NewNoopCluster("test-node")
|
||||
stats := broker.NewStats()
|
||||
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil)
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: ":15684",
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
TLSConfig: &piondtls.Config{},
|
||||
}
|
||||
|
||||
server := New(cfg, b, slog.Default())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := server.Listen(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid DTLS configuration")
|
||||
}
|
||||
if conn.last.Code() != codes.BadRequest {
|
||||
t.Fatalf("expected code %v, got %v", codes.BadRequest, conn.last.Code())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ok", func(t *testing.T) {
|
||||
conn := newStubConn()
|
||||
writer := &stubResponseWriter{conn: conn}
|
||||
|
||||
reqMsg := pool.NewMessage(context.Background())
|
||||
reqMsg.MustSetPath("/mqtt/publish/test/topic")
|
||||
reqMsg.SetBody(bytes.NewReader([]byte("payload")))
|
||||
req := &mux.Message{Message: reqMsg}
|
||||
|
||||
server.handlePublish(writer, req)
|
||||
|
||||
if conn.last == nil {
|
||||
t.Fatal("expected response message")
|
||||
}
|
||||
if conn.last.Code() != codes.Changed {
|
||||
t.Fatalf("expected code %v, got %v", codes.Changed, conn.last.Code())
|
||||
}
|
||||
body, err := conn.last.ReadBody()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body: %v", err)
|
||||
}
|
||||
if string(body) != "ok" {
|
||||
t.Fatalf("expected body %q, got %q", "ok", string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+84
-211
@@ -9,11 +9,12 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/fluxmq/broker"
|
||||
"github.com/absmach/fluxmq/cluster"
|
||||
"github.com/absmach/fluxmq/config"
|
||||
"github.com/absmach/fluxmq/mqtt/broker"
|
||||
clusterv1 "github.com/absmach/fluxmq/pkg/proto/cluster/v1"
|
||||
"github.com/absmach/fluxmq/storage"
|
||||
)
|
||||
@@ -36,22 +37,6 @@ func (m *mockCluster) AcquireSession(ctx context.Context, clientID, nodeID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) AcquirePartition(ctx context.Context, queueName string, partitionID int, nodeID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) ReleasePartition(ctx context.Context, queueName string, partitionID int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) EnqueueRemote(ctx context.Context, nodeID, queueName string, payload []byte, properties map[string]string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) RouteQueueMessage(ctx context.Context, nodeID, clientID, queueName, messageID string, payload []byte, properties map[string]string, sequence int64, partitionID int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) ReleaseSession(ctx context.Context, clientID string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -60,6 +45,10 @@ func (m *mockCluster) GetSessionOwner(ctx context.Context, clientID string) (str
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) WatchSessionOwner(ctx context.Context, clientID string) <-chan cluster.OwnershipChange {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) AddSubscription(ctx context.Context, clientID, filter string, qos byte, opts storage.SubscribeOptions) error {
|
||||
return nil
|
||||
}
|
||||
@@ -76,6 +65,30 @@ func (m *mockCluster) GetSubscribersForTopic(ctx context.Context, topic string)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) RegisterQueueConsumer(ctx context.Context, info *cluster.QueueConsumerInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) UnregisterQueueConsumer(ctx context.Context, queueName, groupID, consumerID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) ListQueueConsumers(ctx context.Context, queueName string) ([]*cluster.QueueConsumerInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) ListQueueConsumersByGroup(ctx context.Context, queueName, groupID string) ([]*cluster.QueueConsumerInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) ListAllQueueConsumers(ctx context.Context) ([]*cluster.QueueConsumerInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) ForwardQueuePublish(ctx context.Context, nodeID, topic string, payload []byte, properties map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) Retained() storage.RetainedStore {
|
||||
return nil
|
||||
}
|
||||
@@ -92,6 +105,10 @@ func (m *mockCluster) TakeoverSession(ctx context.Context, clientID, fromNode, t
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) RouteQueueMessage(ctx context.Context, nodeID, clientID, queueName, messageID string, payload []byte, properties map[string]string, sequence int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockCluster) WaitForLeader(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
@@ -108,57 +125,21 @@ func (m *mockCluster) Nodes() []cluster.NodeInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestServerStartStop(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
func TestAddrWithoutListener(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
server := New(cfg, b, nil, slog.Default())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Listen(ctx)
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("server did not stop in time")
|
||||
server := New(Config{}, b, nil, slog.Default())
|
||||
if server.Addr() != "" {
|
||||
t.Fatalf("expected empty address before listen, got %q", server.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthEndpoint(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
server := New(cfg, b, nil, slog.Default())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go server.Listen(ctx)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr()
|
||||
server := New(Config{}, b, nil, slog.Default())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -186,25 +167,18 @@ func TestHealthEndpoint(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(tt.method, "http://"+addr+"/health", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(tt.method, "http://test/health", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
server.handleHealth(rec, req)
|
||||
|
||||
if resp.StatusCode != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
|
||||
if rec.Code != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, rec.Code)
|
||||
}
|
||||
|
||||
if tt.expectedStatus == http.StatusOK {
|
||||
var response HealthResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
@@ -237,7 +211,7 @@ func TestReadyEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "single node mode - ready",
|
||||
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil),
|
||||
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{}),
|
||||
cluster: nil,
|
||||
method: http.MethodGet,
|
||||
expectedStatus: http.StatusOK,
|
||||
@@ -245,7 +219,7 @@ func TestReadyEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "cluster not initialized - not ready",
|
||||
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil),
|
||||
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{}),
|
||||
cluster: &mockCluster{nodeID: ""},
|
||||
method: http.MethodGet,
|
||||
expectedStatus: http.StatusServiceUnavailable,
|
||||
@@ -254,7 +228,7 @@ func TestReadyEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "cluster initialized - ready",
|
||||
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil),
|
||||
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{}),
|
||||
cluster: &mockCluster{nodeID: "node-1", isLeader: true},
|
||||
method: http.MethodGet,
|
||||
expectedStatus: http.StatusOK,
|
||||
@@ -262,7 +236,7 @@ func TestReadyEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "POST request not allowed",
|
||||
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil),
|
||||
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{}),
|
||||
cluster: nil,
|
||||
method: http.MethodPost,
|
||||
expectedStatus: http.StatusMethodNotAllowed,
|
||||
@@ -275,40 +249,20 @@ func TestReadyEndpoint(t *testing.T) {
|
||||
defer tt.broker.Close()
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
}
|
||||
server := New(Config{}, tt.broker, tt.cluster, slog.Default())
|
||||
|
||||
server := New(cfg, tt.broker, tt.cluster, slog.Default())
|
||||
req := httptest.NewRequest(tt.method, "http://test/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
server.handleReady(rec, req)
|
||||
|
||||
go server.Listen(ctx)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr()
|
||||
|
||||
req, err := http.NewRequest(tt.method, "http://"+addr+"/ready", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
|
||||
if rec.Code != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, rec.Code)
|
||||
}
|
||||
|
||||
if tt.expectedStatus == http.StatusOK || tt.expectedStatus == http.StatusServiceUnavailable {
|
||||
var response ReadyResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
@@ -376,38 +330,18 @@ func TestClusterStatusEndpoint(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
}
|
||||
server := New(Config{}, b, tt.cluster, slog.Default())
|
||||
|
||||
server := New(cfg, b, tt.cluster, slog.Default())
|
||||
req := httptest.NewRequest(tt.method, "http://test/cluster/status", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
server.handleClusterStatus(rec, req)
|
||||
|
||||
go server.Listen(ctx)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr()
|
||||
|
||||
req, err := http.NewRequest(tt.method, "http://"+addr+"/cluster/status", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 2 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
|
||||
if rec.Code != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, rec.Code)
|
||||
}
|
||||
|
||||
if tt.checkMethodNotAllowed {
|
||||
@@ -416,7 +350,7 @@ func TestClusterStatusEndpoint(t *testing.T) {
|
||||
|
||||
if tt.expectedStatus == http.StatusOK {
|
||||
var response ClusterStatusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
@@ -442,41 +376,33 @@ func TestClusterStatusEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestContentTypeHeaders(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
server := New(Config{}, b, nil, slog.Default())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
}{
|
||||
{name: "/health", handler: server.handleHealth},
|
||||
{name: "/ready", handler: server.handleReady},
|
||||
{name: "/cluster/status", handler: server.handleClusterStatus},
|
||||
}
|
||||
|
||||
server := New(cfg, b, nil, slog.Default())
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://test"+tt.name, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
tt.handler(rec, req)
|
||||
|
||||
go server.Listen(ctx)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr()
|
||||
|
||||
endpoints := []string{"/health", "/ready", "/cluster/status"}
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
t.Run(endpoint, func(t *testing.T) {
|
||||
resp, err := http.Get("http://" + addr + endpoint)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
contentType := rec.Header().Get("Content-Type")
|
||||
if contentType != "application/json" {
|
||||
t.Errorf("expected Content-Type application/json, got %q", contentType)
|
||||
}
|
||||
|
||||
// Verify it's valid JSON
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(rec.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read body: %v", err)
|
||||
}
|
||||
@@ -488,56 +414,3 @@ func TestContentTypeHeaders(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
ShutdownTimeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
server := New(cfg, b, nil, slog.Default())
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Listen(ctx)
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr()
|
||||
|
||||
// Make a request while server is running
|
||||
resp, err := http.Get("http://" + addr + "/health")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get health before shutdown: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Trigger shutdown
|
||||
cancel()
|
||||
|
||||
// Server should stop gracefully
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Logf("shutdown completed with: %v", err)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("server did not stop after shutdown timeout")
|
||||
}
|
||||
|
||||
// After shutdown, requests should fail
|
||||
_, err = http.Get("http://" + addr + "/health")
|
||||
if err == nil {
|
||||
t.Error("expected request to fail after shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
+145
-122
@@ -7,217 +7,240 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/fluxmq/broker"
|
||||
"github.com/absmach/fluxmq/config"
|
||||
"github.com/absmach/fluxmq/mqtt/broker"
|
||||
)
|
||||
|
||||
type stubListener struct {
|
||||
conns chan net.Conn
|
||||
closed chan struct{}
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func newStubListener() *stubListener {
|
||||
return &stubListener{
|
||||
conns: make(chan net.Conn, 16),
|
||||
closed: make(chan struct{}),
|
||||
addr: stubAddr("in-memory"),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *stubListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-l.closed:
|
||||
return nil, net.ErrClosed
|
||||
case conn, ok := <-l.conns:
|
||||
if !ok {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *stubListener) Close() error {
|
||||
select {
|
||||
case <-l.closed:
|
||||
return nil
|
||||
default:
|
||||
close(l.closed)
|
||||
close(l.conns)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (l *stubListener) Addr() net.Addr { return l.addr }
|
||||
|
||||
func (l *stubListener) push(conn net.Conn) error {
|
||||
select {
|
||||
case <-l.closed:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
l.conns <- conn
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type stubAddr string
|
||||
|
||||
func (a stubAddr) Network() string { return "stub" }
|
||||
func (a stubAddr) String() string { return string(a) }
|
||||
|
||||
type trackingConn struct {
|
||||
net.Conn
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (c *trackingConn) Close() error {
|
||||
c.closed.Store(true)
|
||||
if c.Conn != nil {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestServerStartStop(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
server := New(cfg, b)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
connCtx, connCancel := context.WithCancel(context.Background())
|
||||
listener := newStubListener()
|
||||
|
||||
// Start server in goroutine
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Listen(ctx)
|
||||
}()
|
||||
server.mu.Lock()
|
||||
server.listener = listener
|
||||
server.mu.Unlock()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify server started
|
||||
if server.Addr() == nil {
|
||||
t.Fatal("server address is nil after start")
|
||||
}
|
||||
|
||||
// Cancel context to stop server
|
||||
acceptDone := server.runAcceptLoop(ctx, connCtx, listener)
|
||||
cancel()
|
||||
|
||||
// Wait for server to stop
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("server did not stop in time")
|
||||
if err := server.gracefulShutdown(listener, acceptDone, connCancel); err != nil {
|
||||
t.Fatalf("unexpected shutdown error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
server := New(cfg, b)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
connCtx, connCancel := context.WithCancel(context.Background())
|
||||
listener := newStubListener()
|
||||
|
||||
// Start server
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Listen(ctx)
|
||||
}()
|
||||
server.mu.Lock()
|
||||
server.listener = listener
|
||||
server.mu.Unlock()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
acceptDone := server.runAcceptLoop(ctx, connCtx, listener)
|
||||
|
||||
// Connect a client
|
||||
conn, err := net.Dial("tcp", server.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect: %v", err)
|
||||
serverConn, clientConn := net.Pipe()
|
||||
if err := listener.push(serverConn); err != nil {
|
||||
t.Fatalf("failed to push connection: %v", err)
|
||||
}
|
||||
clientConn.Close()
|
||||
|
||||
// Trigger shutdown
|
||||
cancel()
|
||||
|
||||
// Close client connection
|
||||
conn.Close()
|
||||
|
||||
// Server should stop gracefully
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Logf("shutdown completed with: %v", err)
|
||||
}
|
||||
case <-time.After(6 * time.Second):
|
||||
t.Fatal("server did not stop after shutdown timeout")
|
||||
if err := server.gracefulShutdown(listener, acceptDone, connCancel); err != nil {
|
||||
t.Fatalf("unexpected shutdown error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionLimit(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
maxConns := 2
|
||||
maxConns := 1
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
MaxConnections: maxConns,
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
server := New(cfg, b)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx := context.Background()
|
||||
|
||||
go server.Listen(ctx)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create max connections
|
||||
conns := make([]net.Conn, maxConns)
|
||||
for i := 0; i < maxConns; i++ {
|
||||
conn, err := net.Dial("tcp", server.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect %d: %v", i, err)
|
||||
}
|
||||
conns[i] = conn
|
||||
s1, c1 := net.Pipe()
|
||||
conn1 := &trackingConn{Conn: s1}
|
||||
if !server.tryAcquireConnectionSlot(ctx, conn1) {
|
||||
t.Fatal("expected first connection to be accepted")
|
||||
}
|
||||
|
||||
// Wait for connections to be accepted
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Try one more connection - should be rejected
|
||||
extraConn, err := net.DialTimeout("tcp", server.Addr().String(), 500*time.Millisecond)
|
||||
if err == nil {
|
||||
extraConn.Close()
|
||||
// Connection might get through briefly before being rejected
|
||||
// Wait and check if it stays open
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
s2, c2 := net.Pipe()
|
||||
conn2 := &trackingConn{Conn: s2}
|
||||
if server.tryAcquireConnectionSlot(ctx, conn2) {
|
||||
t.Fatal("expected second connection to be rejected")
|
||||
}
|
||||
if !conn2.closed.Load() {
|
||||
t.Fatal("expected rejected connection to be closed")
|
||||
}
|
||||
|
||||
// The extra connection should not be handled (or rejected quickly)
|
||||
// We can't easily test this without instrumenting the handler more
|
||||
|
||||
// Clean up
|
||||
for _, conn := range conns {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
server.releaseConnectionSlot()
|
||||
}
|
||||
|
||||
func TestConcurrentConnections(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
ShutdownTimeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
server := New(cfg, b)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
connCtx, connCancel := context.WithCancel(context.Background())
|
||||
listener := newStubListener()
|
||||
|
||||
go server.Listen(ctx)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
server.mu.Lock()
|
||||
server.listener = listener
|
||||
server.mu.Unlock()
|
||||
|
||||
acceptDone := server.runAcceptLoop(ctx, connCtx, listener)
|
||||
|
||||
// Create many concurrent connections
|
||||
numConns := 50
|
||||
numConns := 20
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numConns)
|
||||
|
||||
for i := 0; i < numConns; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn, err := net.Dial("tcp", server.Addr().String())
|
||||
if err != nil {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
if err := listener.push(serverConn); err != nil {
|
||||
return
|
||||
}
|
||||
conn.Write([]byte("test"))
|
||||
conn.Close()
|
||||
clientConn.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// All connections should be handled successfully by the broker
|
||||
cancel()
|
||||
if err := server.gracefulShutdown(listener, acceptDone, connCancel); err != nil {
|
||||
t.Fatalf("unexpected shutdown error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPOptimizations(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
|
||||
func TestDefaultConfigApplied(t *testing.T) {
|
||||
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
|
||||
defer b.Close()
|
||||
|
||||
cfg := Config{
|
||||
Address: "localhost:0",
|
||||
TCPKeepAlive: 15 * time.Second,
|
||||
DisableNoDelay: false,
|
||||
ShutdownTimeout: 1 * time.Second,
|
||||
server := New(Config{}, b)
|
||||
|
||||
if server.config.ShutdownTimeout == 0 {
|
||||
t.Fatal("expected default ShutdownTimeout to be set")
|
||||
}
|
||||
|
||||
server := New(cfg, b)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go server.Listen(ctx)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Connect and test
|
||||
conn, err := net.Dial("tcp", server.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect: %v", err)
|
||||
if server.config.ReadTimeout == 0 {
|
||||
t.Fatal("expected default ReadTimeout to be set")
|
||||
}
|
||||
if server.config.WriteTimeout == 0 {
|
||||
t.Fatal("expected default WriteTimeout to be set")
|
||||
}
|
||||
if server.config.IdleTimeout == 0 {
|
||||
t.Fatal("expected default IdleTimeout to be set")
|
||||
}
|
||||
if server.config.BufferSize == 0 {
|
||||
t.Fatal("expected default BufferSize to be set")
|
||||
}
|
||||
if server.config.TCPKeepAlive == 0 {
|
||||
t.Fatal("expected default TCPKeepAlive to be set")
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// TCP options are applied successfully if connection is accepted
|
||||
}
|
||||
|
||||
+130
-332
@@ -6,63 +6,47 @@ package tcp
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/fluxmq/broker"
|
||||
"github.com/absmach/fluxmq/cluster"
|
||||
"github.com/absmach/fluxmq/config"
|
||||
"github.com/absmach/fluxmq/mqtt/broker"
|
||||
v3 "github.com/absmach/fluxmq/mqtt/packets/v3"
|
||||
"github.com/absmach/fluxmq/storage/memory"
|
||||
)
|
||||
|
||||
func TestTLS_BasicConnection(t *testing.T) {
|
||||
certs := GenerateTestCerts(t)
|
||||
tlsConfig := LoadServerTLSConfig(t, certs, tls.NoClientCert)
|
||||
func newTestBroker(t *testing.T) *broker.Broker {
|
||||
t.Helper()
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
b := broker.NewBroker(nil, nil, logger, nil, nil, nil, nil, config.SessionConfig{})
|
||||
t.Cleanup(func() { _ = b.Close() })
|
||||
return b
|
||||
}
|
||||
|
||||
// Create server with TLS
|
||||
store := memory.New()
|
||||
cl := cluster.NewNoopCluster("test-node")
|
||||
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
|
||||
b := broker.NewBroker(store, cl, nullLogger, nil, nil, nil, nil)
|
||||
func runHandleConnection(s *Server, conn net.Conn) {
|
||||
s.wg.Add(1)
|
||||
go s.handleConnection(context.Background(), conn)
|
||||
}
|
||||
|
||||
cfg := Config{
|
||||
Address: "127.0.0.1:0", // Random port
|
||||
TLSConfig: tlsConfig,
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
Logger: nullLogger,
|
||||
}
|
||||
server := New(cfg, b)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
func waitForConnections(t *testing.T, s *Server) {
|
||||
t.Helper()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
errCh <- server.Listen(ctx)
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr().String()
|
||||
|
||||
// Connect with TLS client
|
||||
clientTLSConfig := LoadClientTLSConfig(t, certs, false)
|
||||
conn, err := tls.Dial("tcp", addr, clientTLSConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect with TLS: %v", err)
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("connection handler did not exit in time")
|
||||
}
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
// Verify TLS handshake completed
|
||||
if err := conn.Handshake(); err != nil {
|
||||
t.Fatalf("TLS handshake failed: %v", err)
|
||||
}
|
||||
|
||||
// Send MQTT CONNECT packet
|
||||
func mqttHandshake(t *testing.T, conn net.Conn) {
|
||||
t.Helper()
|
||||
connectPkt := &v3.Connect{
|
||||
FixedHeader: v3.FixedHeader{PacketType: v3.ConnectType},
|
||||
ProtocolName: "MQTT",
|
||||
@@ -72,339 +56,153 @@ func TestTLS_BasicConnection(t *testing.T) {
|
||||
}
|
||||
|
||||
if err := connectPkt.Pack(conn); err != nil {
|
||||
t.Fatalf("Failed to send CONNECT: %v", err)
|
||||
t.Fatalf("failed to send CONNECT: %v", err)
|
||||
}
|
||||
|
||||
// Read CONNACK
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
connack, err := v3.ReadPacket(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read CONNACK: %v", err)
|
||||
t.Fatalf("failed to read CONNACK: %v", err)
|
||||
}
|
||||
|
||||
if connack.Type() != v3.ConnAckType {
|
||||
t.Fatalf("Expected CONNACK, got %v", connack.Type())
|
||||
t.Fatalf("expected CONNACK, got %v", connack.Type())
|
||||
}
|
||||
|
||||
// Send DISCONNECT packet
|
||||
disconnectPkt := &v3.Disconnect{
|
||||
FixedHeader: v3.FixedHeader{PacketType: v3.DisconnectType},
|
||||
}
|
||||
disconnectPkt.Pack(conn)
|
||||
conn.Close()
|
||||
_ = disconnectPkt.Pack(conn)
|
||||
}
|
||||
|
||||
// Give broker time to process disconnect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Shutdown server
|
||||
cancel()
|
||||
func tlsHandshakeWithTimeout(conn *tls.Conn, timeout time.Duration) error {
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- conn.Handshake()
|
||||
}()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Logf("Server shutdown with error: %v", err)
|
||||
}
|
||||
case <-time.After(6 * time.Second):
|
||||
t.Fatal("Server shutdown timeout")
|
||||
return err
|
||||
case <-time.After(timeout):
|
||||
return errors.New("handshake timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLS_BasicConnection(t *testing.T) {
|
||||
certs := GenerateTestCerts(t)
|
||||
serverTLS := LoadServerTLSConfig(t, certs, tls.NoClientCert)
|
||||
clientTLS := LoadClientTLSConfig(t, certs, false)
|
||||
clientTLS.ServerName = "localhost"
|
||||
|
||||
b := newTestBroker(t)
|
||||
server := New(Config{TLSConfig: serverTLS}, b)
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
tlsServer := tls.Server(serverConn, serverTLS)
|
||||
tlsClient := tls.Client(clientConn, clientTLS)
|
||||
|
||||
runHandleConnection(server, tlsServer)
|
||||
|
||||
if err := tlsHandshakeWithTimeout(tlsClient, 2*time.Second); err != nil {
|
||||
t.Fatalf("TLS handshake failed: %v", err)
|
||||
}
|
||||
mqttHandshake(t, tlsClient)
|
||||
_ = tlsClient.Close()
|
||||
waitForConnections(t, server)
|
||||
}
|
||||
|
||||
func TestTLS_RequireClientCert(t *testing.T) {
|
||||
certs := GenerateTestCerts(t)
|
||||
tlsConfig := LoadServerTLSConfig(t, certs, tls.RequireAndVerifyClientCert)
|
||||
serverTLS := LoadServerTLSConfig(t, certs, tls.RequireAndVerifyClientCert)
|
||||
|
||||
// Verify TLS config is set up correctly
|
||||
t.Logf("Server TLS ClientAuth: %v (expected: %v)", tlsConfig.ClientAuth, tls.RequireAndVerifyClientCert)
|
||||
if tlsConfig.ClientAuth != tls.RequireAndVerifyClientCert {
|
||||
t.Fatalf("Server TLS config ClientAuth not set correctly")
|
||||
}
|
||||
|
||||
// Create server with TLS requiring client cert
|
||||
store := memory.New()
|
||||
cl := cluster.NewNoopCluster("test-node")
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
b := broker.NewBroker(store, cl, logger, nil, nil, nil, nil)
|
||||
|
||||
cfg := Config{
|
||||
Address: "127.0.0.1:0",
|
||||
TLSConfig: tlsConfig,
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
Logger: logger,
|
||||
}
|
||||
server := New(cfg, b)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Listen(ctx)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr().String()
|
||||
|
||||
// Test 1: Connection without client cert should fail
|
||||
t.Run("NoClientCert", func(t *testing.T) {
|
||||
clientTLSConfig := LoadClientTLSConfig(t, certs, false)
|
||||
conn, err := tls.Dial("tcp", addr, clientTLSConfig)
|
||||
if err != nil {
|
||||
// Expected - connection failed during handshake
|
||||
t.Logf("Connection correctly rejected during dial: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
b := newTestBroker(t)
|
||||
server := New(Config{TLSConfig: serverTLS}, b)
|
||||
clientTLS := LoadClientTLSConfig(t, certs, false)
|
||||
clientTLS.ServerName = "localhost"
|
||||
serverConn, clientConn := net.Pipe()
|
||||
tlsServer := tls.Server(serverConn, serverTLS)
|
||||
tlsClient := tls.Client(clientConn, clientTLS)
|
||||
|
||||
// Client-side handshake might succeed, but server should close the connection
|
||||
// Try to read which should fail when server closes due to missing client cert
|
||||
buf := make([]byte, 1)
|
||||
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
_, err = conn.Read(buf)
|
||||
if err != nil {
|
||||
t.Logf("Connection correctly rejected: %v", err)
|
||||
return
|
||||
}
|
||||
runHandleConnection(server, tlsServer)
|
||||
|
||||
t.Fatal("Expected connection to fail without client certificate, but it succeeded")
|
||||
err := tlsHandshakeWithTimeout(tlsClient, 2*time.Second)
|
||||
if err == nil {
|
||||
connectPkt := &v3.Connect{
|
||||
FixedHeader: v3.FixedHeader{PacketType: v3.ConnectType},
|
||||
ProtocolName: "MQTT",
|
||||
ProtocolVersion: 4,
|
||||
CleanSession: true,
|
||||
ClientID: "no-cert-client",
|
||||
}
|
||||
err = connectPkt.Pack(tlsClient)
|
||||
if err == nil {
|
||||
tlsClient.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
_, err = v3.ReadPacket(tlsClient)
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
t.Fatal("expected connection to be rejected without client cert")
|
||||
}
|
||||
_ = tlsClient.Close()
|
||||
waitForConnections(t, server)
|
||||
})
|
||||
|
||||
// Test 2: Connection with client cert should succeed
|
||||
t.Run("WithClientCert", func(t *testing.T) {
|
||||
clientTLSConfig := LoadClientTLSConfig(t, certs, true)
|
||||
conn, err := tls.Dial("tcp", addr, clientTLSConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect with client cert: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
b := newTestBroker(t)
|
||||
server := New(Config{TLSConfig: serverTLS}, b)
|
||||
clientTLS := LoadClientTLSConfig(t, certs, true)
|
||||
clientTLS.ServerName = "localhost"
|
||||
serverConn, clientConn := net.Pipe()
|
||||
tlsServer := tls.Server(serverConn, serverTLS)
|
||||
tlsClient := tls.Client(clientConn, clientTLS)
|
||||
|
||||
// Verify TLS handshake with client cert
|
||||
if err := conn.Handshake(); err != nil {
|
||||
runHandleConnection(server, tlsServer)
|
||||
|
||||
if err := tlsHandshakeWithTimeout(tlsClient, 2*time.Second); err != nil {
|
||||
t.Fatalf("TLS handshake failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify connection state has peer certificates
|
||||
state := conn.ConnectionState()
|
||||
if len(state.PeerCertificates) == 0 {
|
||||
t.Fatal("Server did not receive client certificate")
|
||||
}
|
||||
mqttHandshake(t, tlsClient)
|
||||
_ = tlsClient.Close()
|
||||
waitForConnections(t, server)
|
||||
})
|
||||
|
||||
// Shutdown server
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Logf("Server shutdown with error: %v", err)
|
||||
}
|
||||
case <-time.After(6 * time.Second):
|
||||
t.Fatal("Server shutdown timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLS_InvalidCert(t *testing.T) {
|
||||
func TestTLS_UntrustedServer(t *testing.T) {
|
||||
certs := GenerateTestCerts(t)
|
||||
tlsConfig := LoadServerTLSConfig(t, certs, tls.NoClientCert)
|
||||
serverTLS := LoadServerTLSConfig(t, certs, tls.NoClientCert)
|
||||
|
||||
// Create server
|
||||
store := memory.New()
|
||||
cl := cluster.NewNoopCluster("test-node")
|
||||
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
|
||||
b := broker.NewBroker(store, cl, nullLogger, nil, nil, nil, nil)
|
||||
b := newTestBroker(t)
|
||||
server := New(Config{TLSConfig: serverTLS}, b)
|
||||
|
||||
cfg := Config{
|
||||
Address: "127.0.0.1:0",
|
||||
TLSConfig: tlsConfig,
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
Logger: nullLogger,
|
||||
}
|
||||
server := New(cfg, b)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Listen(ctx)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr().String()
|
||||
|
||||
// Try to connect without trusting the server's CA
|
||||
insecureTLSConfig := &tls.Config{
|
||||
InsecureSkipVerify: false, // Explicitly don't skip verification
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", addr, insecureTLSConfig)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
t.Fatal("Expected connection to fail with unverified certificate")
|
||||
}
|
||||
|
||||
// Shutdown server
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Logf("Server shutdown with error: %v", err)
|
||||
}
|
||||
case <-time.After(6 * time.Second):
|
||||
t.Fatal("Server shutdown timeout")
|
||||
serverConn, clientConn := net.Pipe()
|
||||
tlsServer := tls.Server(serverConn, serverTLS)
|
||||
tlsClient := tls.Client(clientConn, &tls.Config{InsecureSkipVerify: false, ServerName: "localhost"})
|
||||
|
||||
runHandleConnection(server, tlsServer)
|
||||
|
||||
if err := tlsHandshakeWithTimeout(tlsClient, 2*time.Second); err == nil {
|
||||
t.Fatal("expected TLS handshake to fail for untrusted server cert")
|
||||
}
|
||||
_ = tlsClient.Close()
|
||||
waitForConnections(t, server)
|
||||
}
|
||||
|
||||
func TestTLS_MinVersion(t *testing.T) {
|
||||
func TestTLS_MinVersionConfig(t *testing.T) {
|
||||
certs := GenerateTestCerts(t)
|
||||
tlsConfig := LoadServerTLSConfig(t, certs, tls.NoClientCert)
|
||||
|
||||
// Verify minimum TLS version is enforced
|
||||
if tlsConfig.MinVersion != tls.VersionTLS12 {
|
||||
t.Fatalf("Expected MinVersion to be TLS 1.2, got %v", tlsConfig.MinVersion)
|
||||
}
|
||||
|
||||
// Create server
|
||||
store := memory.New()
|
||||
cl := cluster.NewNoopCluster("test-node")
|
||||
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
|
||||
b := broker.NewBroker(store, cl, nullLogger, nil, nil, nil, nil)
|
||||
|
||||
cfg := Config{
|
||||
Address: "127.0.0.1:0",
|
||||
TLSConfig: tlsConfig,
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
Logger: nullLogger,
|
||||
}
|
||||
server := New(cfg, b)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Listen(ctx)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr().String()
|
||||
|
||||
// Try to connect with TLS 1.1 (should fail)
|
||||
clientTLSConfig := LoadClientTLSConfig(t, certs, false)
|
||||
clientTLSConfig.MaxVersion = tls.VersionTLS11
|
||||
|
||||
conn, err := tls.Dial("tcp", addr, clientTLSConfig)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
// Note: This might not fail in all Go versions as the client's MaxVersion
|
||||
// might be overridden. The important thing is the server enforces MinVersion.
|
||||
t.Log("Note: Client was able to connect with TLS 1.1 (client-side compatibility)")
|
||||
} else {
|
||||
t.Logf("Connection correctly rejected with TLS 1.1: %v", err)
|
||||
}
|
||||
|
||||
// Verify TLS 1.2+ works
|
||||
clientTLSConfig.MaxVersion = tls.VersionTLS13
|
||||
clientTLSConfig.MinVersion = tls.VersionTLS12
|
||||
|
||||
conn, err = tls.Dial("tcp", addr, clientTLSConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect with TLS 1.2+: %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
// Shutdown server
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Logf("Server shutdown with error: %v", err)
|
||||
}
|
||||
case <-time.After(6 * time.Second):
|
||||
t.Fatal("Server shutdown timeout")
|
||||
serverTLS := LoadServerTLSConfig(t, certs, tls.NoClientCert)
|
||||
if serverTLS.MinVersion != tls.VersionTLS12 {
|
||||
t.Fatalf("expected MinVersion to be TLS 1.2, got %v", serverTLS.MinVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLS_NoTLS(t *testing.T) {
|
||||
// Verify server works without TLS when TLSConfig is nil
|
||||
store := memory.New()
|
||||
cl := cluster.NewNoopCluster("test-node")
|
||||
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
|
||||
b := broker.NewBroker(store, cl, nullLogger, nil, nil, nil, nil)
|
||||
b := newTestBroker(t)
|
||||
server := New(Config{}, b)
|
||||
|
||||
cfg := Config{
|
||||
Address: "127.0.0.1:0",
|
||||
TLSConfig: nil, // No TLS
|
||||
ShutdownTimeout: 5 * time.Second,
|
||||
Logger: nullLogger,
|
||||
}
|
||||
server := New(cfg, b)
|
||||
serverConn, clientConn := net.Pipe()
|
||||
runHandleConnection(server, serverConn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- server.Listen(ctx)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
addr := server.Addr().String()
|
||||
|
||||
// Connect without TLS
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect without TLS: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Send MQTT CONNECT packet
|
||||
connectPkt := &v3.Connect{
|
||||
FixedHeader: v3.FixedHeader{PacketType: v3.ConnectType},
|
||||
ProtocolName: "MQTT",
|
||||
ProtocolVersion: 4,
|
||||
CleanSession: true,
|
||||
ClientID: "plain-test-client",
|
||||
}
|
||||
|
||||
if err := connectPkt.Pack(conn); err != nil {
|
||||
t.Fatalf("Failed to send CONNECT: %v", err)
|
||||
}
|
||||
|
||||
// Read CONNACK
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
connack, err := v3.ReadPacket(conn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read CONNACK: %v", err)
|
||||
}
|
||||
|
||||
if connack.Type() != v3.ConnAckType {
|
||||
t.Fatalf("Expected CONNACK, got %v", connack.Type())
|
||||
}
|
||||
|
||||
// Send DISCONNECT packet
|
||||
disconnectPkt2 := &v3.Disconnect{
|
||||
FixedHeader: v3.FixedHeader{PacketType: v3.DisconnectType},
|
||||
}
|
||||
disconnectPkt2.Pack(conn)
|
||||
conn.Close()
|
||||
|
||||
// Give broker time to process disconnect
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Shutdown server
|
||||
cancel()
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Logf("Server shutdown with error: %v", err)
|
||||
}
|
||||
case <-time.After(6 * time.Second):
|
||||
t.Fatal("Server shutdown timeout")
|
||||
}
|
||||
mqttHandshake(t, clientConn)
|
||||
_ = clientConn.Close()
|
||||
waitForConnections(t, server)
|
||||
}
|
||||
|
||||
+3
-2
@@ -12,8 +12,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/fluxmq/broker"
|
||||
"github.com/absmach/fluxmq/cluster"
|
||||
"github.com/absmach/fluxmq/config"
|
||||
"github.com/absmach/fluxmq/mqtt/broker"
|
||||
"github.com/absmach/fluxmq/server/tcp"
|
||||
"github.com/absmach/fluxmq/storage/badger"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -193,7 +194,7 @@ func (tc *TestCluster) startNode(node *TestNode, bootstrap bool, peerTransports
|
||||
store.Close()
|
||||
return fmt.Errorf("failed to create cluster: %w", err)
|
||||
}
|
||||
b := broker.NewBroker(store, clust, nullLogger, nil, nil, nil, nil)
|
||||
b := broker.NewBroker(store, clust, nullLogger, nil, nil, nil, nil, config.SessionConfig{})
|
||||
|
||||
// Wire broker as message handler (includes session management)
|
||||
clust.SetMessageHandler(b)
|
||||
|
||||
Reference in New Issue
Block a user