Add support for cursor for AMQP consumers

Signed-off-by: dusan <borovcanindusan1@gmail.com>
This commit is contained in:
dusan
2026-02-04 12:25:45 +01:00
parent 60b6d958b5
commit f695d29c53
29 changed files with 2097 additions and 766 deletions
+3
View File
@@ -33,6 +33,9 @@ func PrefixedClientID(connID string) string {
type QueueManager interface {
Start(ctx context.Context) error
Stop() error
CreateQueue(ctx context.Context, config qtypes.QueueConfig) error
UpdateQueue(ctx context.Context, config qtypes.QueueConfig) error
GetQueue(ctx context.Context, queueName string) (*qtypes.QueueConfig, error)
Publish(ctx context.Context, topic string, payload []byte, properties map[string]string) error
Subscribe(ctx context.Context, queueName, pattern, clientID, groupID, proxyNodeID string) error
SubscribeWithCursor(ctx context.Context, queueName, pattern, clientID, groupID, proxyNodeID string, cursor *qtypes.CursorOption) error
+302 -6
View File
@@ -8,11 +8,15 @@ import (
"context"
"fmt"
"io"
"math"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/absmach/fluxmq/amqp091/codec"
qtypes "github.com/absmach/fluxmq/queue/types"
"github.com/absmach/fluxmq/storage"
"github.com/absmach/fluxmq/topics"
)
@@ -28,6 +32,12 @@ type consumer struct {
exclusive bool
}
type queueInfo struct {
name string
queueType string
args map[string]interface{}
}
// exchange represents a declared exchange (in-memory, per-connection for now).
type exchange struct {
name string
@@ -52,7 +62,7 @@ type Channel struct {
// Exchange/queue/binding state (local to this connection for non-durable)
exchanges map[string]*exchange
queues map[string]bool // set of declared queue names
queues map[string]*queueInfo // declared queues by name
bindings []binding
exchangeMu sync.RWMutex
@@ -115,7 +125,7 @@ func newChannel(c *Connection, id uint16) *Channel {
conn: c,
id: id,
exchanges: make(map[string]*exchange),
queues: make(map[string]bool),
queues: make(map[string]*queueInfo),
consumers: make(map[string]*consumer),
unacked: make(map[uint64]*unackedDelivery),
flow: true,
@@ -335,6 +345,21 @@ func (ch *Channel) completePublish() {
}
}
// RabbitMQ-style stream queue publish: default exchange with routingKey == queue name.
if exchangeName == "" && ch.isStreamQueue(routingKey) {
qm := ch.conn.broker.getQueueManager()
if qm != nil {
queueTopic := "$queue/" + routingKey
if err := qm.Publish(context.Background(), queueTopic, body, props); err != nil {
ch.conn.logger.Error("queue publish failed", "queue", routingKey, "error", err)
}
if ch.confirmMode {
ch.sendPublisherAck()
}
return
}
}
// Check if this targets a queue via exchange bindings
isQueuePublish := false
ch.exchangeMu.RLock()
@@ -549,12 +574,26 @@ func (ch *Channel) sendDelivery(cons *consumer, topic string, payload []byte, pr
return err
}
headers := make(map[string]interface{})
for k, v := range props {
switch k {
case "content-type", "content-encoding", "correlation-id", "reply-to", "message-id", "type":
continue
default:
headers[k] = v
}
}
if len(headers) == 0 {
headers = nil
}
properties := codec.BasicProperties{
ContentType: props["content-type"],
CorrelationID: props["correlation-id"],
ReplyTo: props["reply-to"],
MessageID: props["message-id"],
Type: props["type"],
Headers: headers,
}
headerFrame, err := buildContentHeaderFrame(ch.id, uint64(len(payload)), properties)
@@ -769,10 +808,45 @@ func (ch *Channel) handleQueueDeclare(m *codec.QueueDeclare) error {
m.Queue = fmt.Sprintf("amq.gen-%s-%d", ch.conn.connID, seq)
}
queueType := extractQueueType(m.Arguments)
ch.exchangeMu.Lock()
ch.queues[m.Queue] = true
ch.queues[m.Queue] = &queueInfo{
name: m.Queue,
queueType: queueType,
args: m.Arguments,
}
ch.exchangeMu.Unlock()
if queueType == string(qtypes.QueueTypeStream) {
qm := ch.conn.broker.getQueueManager()
if qm != nil {
queueTopicPattern := "$queue/" + m.Queue + "/#"
var cfg qtypes.QueueConfig
if m.Durable {
cfg = qtypes.DefaultQueueConfig(m.Queue, queueTopicPattern)
} else {
cfg = qtypes.DefaultEphemeralQueueConfig(m.Queue, queueTopicPattern)
}
cfg.Type = qtypes.QueueTypeStream
cfg.Retention = extractStreamRetention(m.Arguments)
if err := qm.CreateQueue(context.Background(), cfg); err != nil {
// If it already exists, attempt to update retention/type only.
if existing, err := qm.GetQueue(context.Background(), m.Queue); err == nil && existing != nil {
existing.Type = qtypes.QueueTypeStream
existing.Retention = cfg.Retention
if !m.Durable {
existing.Durable = false
if existing.ExpiresAfter == 0 {
existing.ExpiresAfter = 5 * time.Minute
}
}
_ = qm.UpdateQueue(context.Background(), *existing)
}
}
}
}
// Auto-bind queue to default exchange with routing key = queue name
ch.exchangeMu.Lock()
ch.bindings = append(ch.bindings, binding{
@@ -877,7 +951,19 @@ func (ch *Channel) handleBasicConsume(m *codec.BasicConsume) error {
if isQueue {
queueName, pattern = parseQueueFilter(queueFilter)
}
queueInfo := ch.getQueueInfo(queueFilter)
streamCursor, hasStreamOffset := extractStreamOffset(m.Arguments)
isStream := (queueInfo != nil && queueInfo.queueType == string(qtypes.QueueTypeStream)) || hasStreamOffset
if isStream && !isQueue {
queueName = queueFilter
pattern = ""
}
groupID := extractConsumerGroup(m.Arguments)
if isStream && groupID == "" {
groupID = tag
}
ch.consumersMu.Lock()
if _, exists := ch.consumers[tag]; exists {
@@ -900,15 +986,25 @@ func (ch *Channel) handleBasicConsume(m *codec.BasicConsume) error {
// Subscribe to the queue via queue manager
qm := ch.conn.broker.getQueueManager()
if qm != nil && isQueue && queueName != "" {
if qm != nil && (isQueue || isStream) && queueName != "" {
clientID := PrefixedClientID(ch.conn.connID)
if err := qm.Subscribe(context.Background(), queueName, pattern, clientID, groupID, ""); err != nil {
subGroupID := groupID
if isStream {
cursor := streamCursor
if cursor == nil {
cursor = &qtypes.CursorOption{Position: qtypes.CursorLatest}
}
cursor.Mode = qtypes.GroupModeStream
if err := qm.SubscribeWithCursor(context.Background(), queueName, pattern, clientID, subGroupID, "", cursor); err != nil {
ch.conn.logger.Error("queue subscribe with cursor failed", "queue", queueName, "error", err)
}
} else if err := qm.Subscribe(context.Background(), queueName, pattern, clientID, subGroupID, ""); err != nil {
ch.conn.logger.Error("queue subscribe failed", "queue", queueName, "error", err)
}
}
// Subscribe via the topic router for pub/sub delivery (non-queue topics).
if !isQueue {
if !isQueue && !isStream {
connID := ch.conn.connID
ch.conn.broker.router.Subscribe(connID, queueFilter, 1, storage.SubscribeOptions{})
}
@@ -1172,3 +1268,203 @@ func extractConsumerGroup(args map[string]interface{}) string {
return fmt.Sprintf("%v", v)
}
}
func (ch *Channel) getQueueInfo(name string) *queueInfo {
ch.exchangeMu.RLock()
defer ch.exchangeMu.RUnlock()
return ch.queues[name]
}
func (ch *Channel) isStreamQueue(name string) bool {
if info := ch.getQueueInfo(name); info != nil && info.queueType == string(qtypes.QueueTypeStream) {
return true
}
if strings.HasPrefix(name, "$queue/") {
base := strings.TrimPrefix(name, "$queue/")
if info := ch.getQueueInfo(base); info != nil && info.queueType == string(qtypes.QueueTypeStream) {
return true
}
}
return false
}
func extractQueueType(args map[string]interface{}) string {
if len(args) == 0 {
return string(qtypes.QueueTypeClassic)
}
val, ok := args["x-queue-type"]
if !ok {
return string(qtypes.QueueTypeClassic)
}
switch v := val.(type) {
case string:
if v == "" {
return string(qtypes.QueueTypeClassic)
}
return strings.ToLower(v)
case []byte:
if len(v) == 0 {
return string(qtypes.QueueTypeClassic)
}
return strings.ToLower(string(v))
default:
return string(qtypes.QueueTypeClassic)
}
}
func extractStreamRetention(args map[string]interface{}) qtypes.RetentionPolicy {
var policy qtypes.RetentionPolicy
if len(args) == 0 {
return policy
}
if val, ok := args["x-max-age"]; ok {
if d, ok := parseDurationArg(val); ok {
policy.RetentionTime = d
}
}
if val, ok := args["x-max-length-bytes"]; ok {
if n, ok := parseInt64Arg(val); ok {
policy.RetentionBytes = n
}
}
if val, ok := args["x-max-length"]; ok {
if n, ok := parseInt64Arg(val); ok {
policy.RetentionMessages = n
}
}
return policy
}
func extractStreamOffset(args map[string]interface{}) (*qtypes.CursorOption, bool) {
if len(args) == 0 {
return nil, false
}
val, ok := args["x-stream-offset"]
if !ok {
return nil, false
}
switch v := val.(type) {
case string:
return parseStreamOffsetString(v)
case []byte:
return parseStreamOffsetString(string(v))
case int:
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: uint64(v)}, true
case int64:
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: uint64(v)}, true
case uint64:
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: v}, true
case uint32:
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: uint64(v)}, true
case time.Time:
return &qtypes.CursorOption{Position: qtypes.CursorTimestamp, Timestamp: v}, true
default:
return nil, false
}
}
func parseStreamOffsetString(val string) (*qtypes.CursorOption, bool) {
if val == "" {
return nil, false
}
v := strings.ToLower(strings.TrimSpace(val))
switch v {
case "first":
return &qtypes.CursorOption{Position: qtypes.CursorEarliest}, true
case "last", "next":
return &qtypes.CursorOption{Position: qtypes.CursorLatest}, true
}
if strings.HasPrefix(v, "offset=") {
v = strings.TrimPrefix(v, "offset=")
}
if strings.HasPrefix(v, "timestamp=") {
raw := strings.TrimPrefix(v, "timestamp=")
if ts, ok := parseUnixTimestamp(raw); ok {
return &qtypes.CursorOption{Position: qtypes.CursorTimestamp, Timestamp: ts}, true
}
}
if off, err := strconv.ParseUint(v, 10, 64); err == nil {
return &qtypes.CursorOption{Position: qtypes.CursorOffset, Offset: off}, true
}
return nil, false
}
func parseUnixTimestamp(raw string) (time.Time, bool) {
if raw == "" {
return time.Time{}, false
}
val, err := strconv.ParseInt(raw, 10, 64)
if err != nil {
return time.Time{}, false
}
if val > 1e12 {
return time.UnixMilli(val), true
}
return time.Unix(val, 0), true
}
func parseDurationArg(val any) (time.Duration, bool) {
switch v := val.(type) {
case time.Duration:
return v, true
case string:
trimmed := strings.TrimSpace(v)
if trimmed == "" {
return 0, false
}
if d, err := time.ParseDuration(trimmed); err == nil {
return d, true
}
upper := strings.ToUpper(trimmed)
if strings.HasSuffix(upper, "D") {
num := strings.TrimSuffix(upper, "D")
if f, err := strconv.ParseFloat(num, 64); err == nil {
return time.Duration(f * float64(24*time.Hour)), true
}
}
if strings.HasSuffix(upper, "W") {
num := strings.TrimSuffix(upper, "W")
if f, err := strconv.ParseFloat(num, 64); err == nil {
return time.Duration(f * float64(7*24*time.Hour)), true
}
}
case int:
return time.Duration(v) * time.Second, true
case int64:
return time.Duration(v) * time.Second, true
case uint64:
return time.Duration(v) * time.Second, true
}
return 0, false
}
func parseInt64Arg(val any) (int64, bool) {
switch v := val.(type) {
case int64:
return v, true
case int:
return int64(v), true
case uint64:
if v > math.MaxInt64 {
return 0, false
}
return int64(v), true
case uint32:
return int64(v), true
case string:
if n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64); err == nil {
return n, true
}
case []byte:
if n, err := strconv.ParseInt(strings.TrimSpace(string(v)), 10, 64); err == nil {
return n, true
}
}
return 0, false
}
+23
View File
@@ -11,6 +11,7 @@ import (
"testing"
"github.com/absmach/fluxmq/amqp091/codec"
qtypes "github.com/absmach/fluxmq/queue/types"
)
func newTestChannel(t *testing.T) (*Channel, *bytes.Buffer) {
@@ -262,3 +263,25 @@ func TestExchangeNotFoundOnPublish(t *testing.T) {
t.Fatalf("expected NotFound, got %d", closeMsg.ReplyCode)
}
}
func TestParseStreamOffsetString(t *testing.T) {
first, ok := parseStreamOffsetString("first")
if !ok || first.Position != qtypes.CursorEarliest {
t.Fatalf("expected earliest, got %+v", first)
}
last, ok := parseStreamOffsetString("last")
if !ok || last.Position != qtypes.CursorLatest {
t.Fatalf("expected latest, got %+v", last)
}
offset, ok := parseStreamOffsetString("offset=42")
if !ok || offset.Position != qtypes.CursorOffset || offset.Offset != 42 {
t.Fatalf("expected offset 42, got %+v", offset)
}
ts, ok := parseStreamOffsetString("timestamp=1700000000")
if !ok || ts.Position != qtypes.CursorTimestamp || ts.Timestamp.IsZero() {
t.Fatalf("expected timestamp, got %+v", ts)
}
}
+346
View File
@@ -4,6 +4,7 @@
package amqp091
import (
"math"
"strconv"
"strings"
"time"
@@ -18,6 +19,31 @@ type QueuePublishOptions struct {
Properties map[string]string // Optional message properties
}
// StreamQueueOptions configures a stream queue declaration.
type StreamQueueOptions struct {
Name string
Durable bool
AutoDelete bool
Exclusive bool
NoWait bool
MaxAge string // e.g. "7D", "1h"
MaxLengthBytes int64
MaxLengthMessages int64
}
// StreamConsumeOptions configures a stream queue subscription.
type StreamConsumeOptions struct {
QueueName string
ConsumerGroup string
Offset string // "first", "last", "next", "offset=123", "timestamp=..."
AutoAck bool
Exclusive bool
NoLocal bool
NoWait bool
ConsumerTag string
Arguments amqp091.Table
}
// QueueMessageHandler is called when a queue message is received.
type QueueMessageHandler func(msg *QueueMessage)
@@ -49,6 +75,34 @@ func (qm *QueueMessage) Reject() error {
})
}
// StreamOffset returns the stream offset if present.
func (qm *QueueMessage) StreamOffset() (uint64, bool) {
return headerUint64(qm.Headers, "x-stream-offset")
}
// StreamTimestamp returns the stream timestamp (unix millis) if present.
func (qm *QueueMessage) StreamTimestamp() (int64, bool) {
if v, ok := headerInt64(qm.Headers, "x-stream-timestamp"); ok {
return v, true
}
return 0, false
}
// WorkAcked reports whether the primary work group has acknowledged this offset.
func (qm *QueueMessage) WorkAcked() (bool, bool) {
return headerBool(qm.Headers, "x-work-acked")
}
// WorkCommittedOffset returns the primary group's committed offset if present.
func (qm *QueueMessage) WorkCommittedOffset() (uint64, bool) {
return headerUint64(qm.Headers, "x-work-committed-offset")
}
// WorkGroup returns the primary work group name if present.
func (qm *QueueMessage) WorkGroup() (string, bool) {
return headerString(qm.Headers, "x-work-group")
}
func (qm *QueueMessage) withChannelLock(fn func() error) error {
if qm.client == nil {
return fn()
@@ -111,6 +165,66 @@ func (c *Client) PublishToQueueWithOptions(opts *QueuePublishOptions) error {
return c.publish("", queueTopic, publishing, false, false)
}
// PublishToStream publishes a message to a stream queue (RabbitMQ-style queue name).
func (c *Client) PublishToStream(queueName string, payload []byte, props map[string]string) error {
if !c.connected.Load() {
return ErrNotConnected
}
if queueName == "" {
return ErrInvalidQueueName
}
publishing := amqp091.Publishing{
Timestamp: time.Now(),
Body: payload,
}
applyProperties(&publishing, props)
return c.publish("", queueName, publishing, false, false)
}
// DeclareStreamQueue declares a stream queue with retention settings.
func (c *Client) DeclareStreamQueue(opts *StreamQueueOptions) (string, error) {
if !c.connected.Load() {
return "", ErrNotConnected
}
if opts == nil {
return "", ErrInvalidQueueName
}
args := amqp091.Table{
"x-queue-type": "stream",
}
if opts.MaxAge != "" {
args["x-max-age"] = opts.MaxAge
}
if opts.MaxLengthBytes > 0 {
args["x-max-length-bytes"] = opts.MaxLengthBytes
}
if opts.MaxLengthMessages > 0 {
args["x-max-length"] = opts.MaxLengthMessages
}
ch, err := c.channel()
if err != nil {
return "", err
}
c.chMu.Lock()
q, err := ch.QueueDeclare(
opts.Name,
opts.Durable,
opts.AutoDelete,
opts.Exclusive,
opts.NoWait,
args,
)
c.chMu.Unlock()
if err != nil {
return "", err
}
return q.Name, nil
}
// SubscribeToQueue subscribes to a durable queue with a consumer group.
// The queueName should NOT include the "$queue/" prefix - it will be added automatically.
// The handler will be called for each message received from the queue.
@@ -154,6 +268,51 @@ func (c *Client) SubscribeToQueue(queueName, consumerGroup string, handler Queue
return nil
}
// SubscribeToStream subscribes to a stream queue with cursor control.
func (c *Client) SubscribeToStream(opts *StreamConsumeOptions, handler QueueMessageHandler) error {
if !c.connected.Load() {
return ErrNotConnected
}
if opts == nil || opts.QueueName == "" {
return ErrInvalidQueueName
}
if handler == nil {
return ErrNilHandler
}
queueName := opts.QueueName
c.subsMu.Lock()
if _, exists := c.queueSubs[queueName]; exists {
c.subsMu.Unlock()
return ErrAlreadySubscribed
}
consumerTag := opts.ConsumerTag
if consumerTag == "" {
consumerTag = "ctag-" + strings.ReplaceAll(queueName, "/", "-") + "-" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
sub := &queueSubscription{
queueName: queueName,
queueTopic: queueName,
consumerGroup: opts.ConsumerGroup,
consumerTag: consumerTag,
handler: handler,
done: make(chan struct{}),
}
c.queueSubs[queueName] = sub
c.subsMu.Unlock()
if err := c.subscribeStream(sub, opts); err != nil {
c.subsMu.Lock()
delete(c.queueSubs, queueName)
c.subsMu.Unlock()
return err
}
return nil
}
// UnsubscribeFromQueue unsubscribes from a durable queue.
// The queueName should NOT include the "$queue/" prefix - it will be added automatically.
func (c *Client) UnsubscribeFromQueue(queueName string) error {
@@ -182,6 +341,31 @@ func (c *Client) UnsubscribeFromQueue(queueName string) error {
return ch.Cancel(sub.consumerTag, false)
}
// UnsubscribeFromStream unsubscribes from a stream queue.
func (c *Client) UnsubscribeFromStream(queueName string) error {
c.subsMu.Lock()
sub, ok := c.queueSubs[queueName]
if ok {
delete(c.queueSubs, queueName)
}
c.subsMu.Unlock()
if !ok {
return nil
}
sub.close()
ch, err := c.channel()
if err != nil {
return err
}
c.chMu.Lock()
defer c.chMu.Unlock()
return ch.Cancel(sub.consumerTag, false)
}
func (c *Client) subscribeQueue(sub *queueSubscription) error {
ch, err := c.channel()
if err != nil {
@@ -229,6 +413,72 @@ func (c *Client) subscribeQueue(sub *queueSubscription) error {
return nil
}
func (c *Client) subscribeStream(sub *queueSubscription, opts *StreamConsumeOptions) error {
ch, err := c.channel()
if err != nil {
return err
}
args := amqp091.Table{}
if opts != nil {
if opts.ConsumerGroup != "" {
args["x-consumer-group"] = opts.ConsumerGroup
}
if opts.Offset != "" {
args["x-stream-offset"] = opts.Offset
}
for k, v := range opts.Arguments {
args[k] = v
}
}
autoAck := false
exclusive := false
noLocal := false
noWait := false
if opts != nil {
autoAck = opts.AutoAck
exclusive = opts.Exclusive
noLocal = opts.NoLocal
noWait = opts.NoWait
}
c.chMu.Lock()
deliveries, err := ch.Consume(
sub.queueTopic,
sub.consumerTag,
autoAck,
exclusive,
noLocal,
noWait,
args,
)
c.chMu.Unlock()
if err != nil {
return err
}
go func() {
for {
select {
case <-sub.done:
return
case d, ok := <-deliveries:
if !ok {
return
}
sub.handler(&QueueMessage{
Delivery: d,
queueName: sub.queueName,
client: c,
})
}
}
}()
return nil
}
func normalizeQueueTopic(queueName string) string {
if strings.HasPrefix(queueName, "$queue/") {
return queueName
@@ -267,3 +517,99 @@ func applyProperties(p *amqp091.Publishing, props map[string]string) {
}
}
}
func headerUint64(headers amqp091.Table, key string) (uint64, bool) {
if headers == nil {
return 0, false
}
val, ok := headers[key]
if !ok {
return 0, false
}
switch v := val.(type) {
case uint64:
return v, true
case uint32:
return uint64(v), true
case int64:
if v < 0 {
return 0, false
}
return uint64(v), true
case int:
if v < 0 {
return 0, false
}
return uint64(v), true
case string:
if n, err := strconv.ParseUint(v, 10, 64); err == nil {
return n, true
}
}
return 0, false
}
func headerInt64(headers amqp091.Table, key string) (int64, bool) {
if headers == nil {
return 0, false
}
val, ok := headers[key]
if !ok {
return 0, false
}
switch v := val.(type) {
case int64:
return v, true
case int:
return int64(v), true
case uint64:
if v > math.MaxInt64 {
return 0, false
}
return int64(v), true
case string:
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
return n, true
}
}
return 0, false
}
func headerBool(headers amqp091.Table, key string) (bool, bool) {
if headers == nil {
return false, false
}
val, ok := headers[key]
if !ok {
return false, false
}
switch v := val.(type) {
case bool:
return v, true
case string:
if v == "true" {
return true, true
}
if v == "false" {
return false, true
}
}
return false, false
}
func headerString(headers amqp091.Table, key string) (string, bool) {
if headers == nil {
return "", false
}
val, ok := headers[key]
if !ok {
return "", false
}
switch v := val.(type) {
case string:
return v, true
case []byte:
return string(v), true
}
return "", false
}
+40
View File
@@ -0,0 +1,40 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package amqp091
import (
"testing"
amqp091 "github.com/rabbitmq/amqp091-go"
)
func TestQueueMessageStreamMetadata(t *testing.T) {
msg := &QueueMessage{
Delivery: amqp091.Delivery{
Headers: amqp091.Table{
"x-stream-offset": "42",
"x-stream-timestamp": "1700000000000",
"x-work-acked": "true",
"x-work-committed-offset": "100",
"x-work-group": "workers",
},
},
}
if off, ok := msg.StreamOffset(); !ok || off != 42 {
t.Fatalf("expected stream offset 42, got %d (ok=%v)", off, ok)
}
if ts, ok := msg.StreamTimestamp(); !ok || ts != 1700000000000 {
t.Fatalf("expected stream timestamp, got %d (ok=%v)", ts, ok)
}
if acked, ok := msg.WorkAcked(); !ok || !acked {
t.Fatalf("expected work acked true, got %v (ok=%v)", acked, ok)
}
if committed, ok := msg.WorkCommittedOffset(); !ok || committed != 100 {
t.Fatalf("expected committed offset 100, got %d (ok=%v)", committed, ok)
}
if group, ok := msg.WorkGroup(); !ok || group != "workers" {
t.Fatalf("expected work group workers, got %q (ok=%v)", group, ok)
}
}
+1
View File
@@ -18,6 +18,7 @@ type QueueConsumerInfo struct {
ConsumerID string // Consumer identifier (usually client ID)
ClientID string // MQTT client ID
Pattern string // Subscription pattern within the queue
Mode string // Consumer group mode (queue or stream)
ProxyNodeID string // Node where the consumer is connected
RegisteredAt time.Time
}
+10 -3
View File
@@ -19,10 +19,10 @@ import (
amqp091broker "github.com/absmach/fluxmq/amqp091/broker"
"github.com/absmach/fluxmq/broker/webhook"
"github.com/absmach/fluxmq/cluster"
clusterv1 "github.com/absmach/fluxmq/pkg/proto/cluster/v1"
"github.com/absmach/fluxmq/config"
logStorage "github.com/absmach/fluxmq/logstorage"
"github.com/absmach/fluxmq/mqtt/broker"
clusterv1 "github.com/absmach/fluxmq/pkg/proto/cluster/v1"
mqtttls "github.com/absmach/fluxmq/pkg/tls"
"github.com/absmach/fluxmq/queue"
"github.com/absmach/fluxmq/queue/raft"
@@ -47,8 +47,8 @@ import (
// messageDispatcher routes cluster-delivered messages to the appropriate protocol broker.
type messageDispatcher struct {
mqtt cluster.MessageHandler
amqp *amqpbroker.Broker
mqtt cluster.MessageHandler
amqp *amqpbroker.Broker
amqp091 *amqp091broker.Broker
}
@@ -339,6 +339,8 @@ func main() {
Name: qc.Name,
Topics: qc.Topics,
Reserved: qc.Reserved,
Type: queueTypes.QueueType(qc.Type),
PrimaryGroup: qc.PrimaryGroup,
MaxMessageSize: qc.Limits.MaxMessageSize,
MaxDepth: qc.Limits.MaxDepth,
MessageTTL: qc.Limits.MessageTTL,
@@ -348,6 +350,11 @@ func main() {
Multiplier: qc.Retry.Multiplier,
DLQEnabled: qc.DLQ.Enabled,
DLQTopic: qc.DLQ.Topic,
Retention: queueTypes.RetentionPolicy{
RetentionTime: qc.Retention.MaxAge,
RetentionBytes: qc.Retention.MaxLengthBytes,
RetentionMessages: qc.Retention.MaxLengthMessages,
},
}))
}
+25 -15
View File
@@ -15,25 +15,28 @@ import (
// Config holds all configuration for the MQTT broker.
type Config struct {
Server ServerConfig `yaml:"server"`
Broker BrokerConfig `yaml:"broker"`
Session SessionConfig `yaml:"session"`
Log LogConfig `yaml:"log"`
Storage StorageConfig `yaml:"storage"`
Cluster ClusterConfig `yaml:"cluster"`
Webhook WebhookConfig `yaml:"webhook"`
RateLimit RateLimitConfig `yaml:"ratelimit"`
Queues []QueueConfig `yaml:"queues"`
Server ServerConfig `yaml:"server"`
Broker BrokerConfig `yaml:"broker"`
Session SessionConfig `yaml:"session"`
Log LogConfig `yaml:"log"`
Storage StorageConfig `yaml:"storage"`
Cluster ClusterConfig `yaml:"cluster"`
Webhook WebhookConfig `yaml:"webhook"`
RateLimit RateLimitConfig `yaml:"ratelimit"`
Queues []QueueConfig `yaml:"queues"`
}
// QueueConfig defines configuration for a persistent queue.
type QueueConfig struct {
Name string `yaml:"name"`
Topics []string `yaml:"topics"`
Reserved bool `yaml:"reserved"`
Limits QueueLimits `yaml:"limits"`
Retry QueueRetry `yaml:"retry"`
DLQ QueueDLQ `yaml:"dlq"`
Name string `yaml:"name"`
Topics []string `yaml:"topics"`
Reserved bool `yaml:"reserved"`
Type string `yaml:"type"`
PrimaryGroup string `yaml:"primary_group"`
Retention QueueRetention `yaml:"retention"`
Limits QueueLimits `yaml:"limits"`
Retry QueueRetry `yaml:"retry"`
DLQ QueueDLQ `yaml:"dlq"`
}
// QueueLimits defines resource limits for a queue.
@@ -57,6 +60,13 @@ type QueueDLQ struct {
Topic string `yaml:"topic"`
}
// QueueRetention defines retention policy for a queue.
type QueueRetention struct {
MaxAge time.Duration `yaml:"max_age"`
MaxLengthBytes int64 `yaml:"max_length_bytes"`
MaxLengthMessages int64 `yaml:"max_length_messages"`
}
// RateLimitConfig holds rate limiting configuration.
type RateLimitConfig struct {
Enabled bool `yaml:"enabled"`
+48
View File
@@ -520,6 +520,54 @@ func main() {
- `SubscribeToQueue` passes the consumer group via `x-consumer-group` on `basic.consume`.
- `Ack`, `Nack`, and `Reject` map to `basic.ack`, `basic.nack`, and `basic.reject`.
### Stream Queues (RabbitMQ-Compatible)
Stream queues provide log-style consumption with cursor offsets.
Stream queue names follow RabbitMQ conventions (no `$queue/` prefix).
Supported offsets: `first`, `last`, `next`, `offset=<n>`, `timestamp=<unix>`.
```go
// Declare a stream queue
qName, err := c.DeclareStreamQueue(&amqp091.StreamQueueOptions{
Name: "events",
Durable: true,
MaxAge: "7D",
MaxLengthBytes: 10 * 1024 * 1024 * 1024,
})
if err != nil {
log.Fatal(err)
}
log.Printf("stream queue: %s", qName)
// Consume from the beginning
err = c.SubscribeToStream(&amqp091.StreamConsumeOptions{
QueueName: "events",
Offset: "first",
}, func(msg *amqp091.QueueMessage) {
if off, ok := msg.StreamOffset(); ok {
log.Printf("offset=%d payload=%s", off, string(msg.Body))
}
_ = msg.Ack()
})
if err != nil {
log.Fatal(err)
}
// Publish to the stream queue (RabbitMQ-style)
if err := c.PublishToStream("events", []byte("hello"), nil); err != nil {
log.Fatal(err)
}
```
Stream deliveries include:
- `x-stream-offset`
- `x-stream-timestamp`
- `x-work-acked` / `x-work-committed-offset`
The `x-work-*` fields report the configured primary work groups committed offset.
Convenience accessors are available on `QueueMessage`:
`StreamOffset()`, `StreamTimestamp()`, `WorkAcked()`, `WorkCommittedOffset()`, `WorkGroup()`.
### Pub/Sub
```go
+23
View File
@@ -525,6 +525,12 @@ queues:
- "orders/#"
- "$queue/orders/#" # allow explicit queue publish
reserved: false
type: "stream" # classic|stream (optional)
primary_group: "workers"
retention:
max_age: "168h"
max_length_bytes: 10737418240
max_length_messages: 1000000
limits:
max_message_size: 1048576
max_depth: 100000
@@ -554,6 +560,23 @@ queue, include `$queue/<name>/#` in the list.
Marks a queue as system-reserved (cannot be deleted via management APIs).
### type
Queue behavior mode:
- `classic` (default): work-queue semantics (acks drive committed offset)
- `stream`: append-only log semantics with cursor-based consumption
### primary_group
Consumer group name used to compute delivery status for stream consumers.
### retention
Retention policy for stream-style access:
- `max_age`: time-based retention window
- `max_length_bytes`: size-based retention window
- `max_length_messages`: message-count window
### limits / retry / dlq
Per-queue limits, retry policy, and dead-letter behavior. If `dlq.topic` is
+65 -4
View File
@@ -35,6 +35,7 @@ Queues provide durable, at-least-once delivery across protocols:
- Redelivery via visibility timeouts and work stealing
- DLQ handler exists but is not wired into the main delivery path
- FIFO order per queue and per consumer group (single cursor)
- Stream queues (RabbitMQ-compatible) for event-log consumption with cursor offsets
Queues are integrated with MQTT and AMQP:
- MQTT uses `$queue/<name>/...` topics
@@ -78,6 +79,14 @@ $queue/tasks/image-processing/$reject → Reject (DLQ wiring is planned)
$dlq/{queue-name} → Dead-letter queue
```
Stream queues use RabbitMQ-compatible queue names and arguments:
- Declare with `x-queue-type=stream`
- Consume with `x-stream-offset`
- Retention via `x-max-age`, `x-max-length-bytes`, `x-max-length`
The `$queue/<name>` prefix remains supported for legacy queue clients.
**Acknowledgment requirements**:
- `message-id` and `group-id` must be provided in message properties.
- For MQTT, these are MQTT v5 User Properties.
@@ -90,6 +99,13 @@ Queue deliveries include properties:
- `queue`
- `offset`
Stream deliveries also include:
- `x-stream-offset`
- `x-stream-timestamp`
- `x-work-acked` (based on the primary work groups committed offset)
- `x-work-committed-offset`
- `x-work-group`
### Message Flow (Queue Publish)
```
@@ -119,6 +135,46 @@ but DLQ moves are not currently triggered automatically.
---
## RabbitMQ Stream Compatibility
FluxMQ supports RabbitMQ-style stream queues at the protocol level:
- `queue.declare` with `x-queue-type=stream`
- `basic.consume` with `x-stream-offset`
- `x-max-age`, `x-max-length-bytes`, `x-max-length` for retention
Stream queues are append-only: acks do **not** delete messages. Messages are
removed only when retention policies allow it, and only up to the safe
truncation point for queue-mode consumers.
`x-stream-offset` supports:
- `first`
- `last`
- `next`
- `offset=<n>`
- `timestamp=<unix-seconds|unix-millis>`
FluxMQ extensions for stream consumers:
- `x-work-acked` and `x-work-committed-offset` to report delivery status for the
configured primary work group.
- `x-work-group` to identify the group used for status.
- Optional `x-consumer-group` on `basic.consume` to persist a shared cursor.
If omitted, the consumer tag is used as the stream group ID.
Primary work group is configured per queue (see configuration section) and is
used only for delivery status reporting; it does not affect routing.
## Model Alignment
FluxMQs queue model aligns with:
- **Kafka**: append-only log + consumer groups + retention.
- **Pulsar**: subscriptions + retention window for replay.
- **NATS JetStream**: queue vs log semantics are defined by consumer mode.
Stream queues provide log-like semantics, while classic queue groups preserve
work-queue behavior.
## Log Storage Engine
Queues are backed by an **append-only log** with segments and sparse indexes.
@@ -151,10 +207,15 @@ Each queue has a single log (no partitions).
### Retention
The queue manager periodically truncates logs to the **minimum committed offset**
across all consumer groups. Time-based or size-based retention policies exist in
`logstorage` (segment manager retention) but are not wired into the runtime
loop or exposed via the main config.
The queue manager truncates logs to a **safe offset** that respects:
- The minimum committed offset across **queue-mode** consumer groups
- The queues retention policy (time/size/message-count)
This means:
- Queue-mode consumers never lose unacked data.
- Stream consumers do not block truncation.
- Retention keeps data available for event-log readers as configured.
---
+3 -2
View File
@@ -11,7 +11,8 @@ import (
"testing"
"time"
"github.com/absmach/fluxmq/broker"
"github.com/absmach/fluxmq/config"
"github.com/absmach/fluxmq/mqtt/broker"
"github.com/absmach/fluxmq/server/tcp"
"github.com/absmach/fluxmq/storage/badger"
"github.com/absmach/fluxmq/testutil"
@@ -40,7 +41,7 @@ func startBroker(t *testing.T, dataDir string, tcpPort int) *testBrokerInstance
require.NoError(t, err)
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
b := broker.NewBroker(store, nil, nullLogger, nil, nil, nil, nil)
b := broker.NewBroker(store, nil, nullLogger, nil, nil, nil, nil, config.SessionConfig{})
tcpAddr := fmt.Sprintf("127.0.0.1:%d", tcpPort)
tcpCfg := tcp.Config{
+10
View File
@@ -100,6 +100,16 @@ func (a *Adapter) Store() *Store {
return a.store
}
// OffsetByTime returns the offset for the given timestamp.
func (a *Adapter) OffsetByTime(ctx context.Context, queueName string, ts time.Time) (uint64, error) {
return a.store.LookupByTime(queueName, ts)
}
// OffsetBySize returns the offset to keep when enforcing size retention.
func (a *Adapter) OffsetBySize(ctx context.Context, queueName string, retentionBytes int64) (uint64, error) {
return a.store.RetentionOffsetBySize(queueName, retentionBytes)
}
// QueueStore interface implementation
// CreateQueue creates a new queue with the given configuration.
+6
View File
@@ -81,6 +81,9 @@ func (s *ConsumerGroupStateStore) loadAll() error {
if state.Cursor == nil {
state.Cursor = &types.QueueCursor{}
}
if state.Mode == "" {
state.Mode = types.GroupModeQueue
}
if state.PEL == nil {
state.PEL = make(map[string][]*types.PendingEntry)
}
@@ -143,6 +146,9 @@ func (s *ConsumerGroupStateStore) loadGroup(queueName, groupID string) (*types.C
if state.Cursor == nil {
state.Cursor = &types.QueueCursor{}
}
if state.Mode == "" {
state.Mode = types.GroupModeQueue
}
if state.PEL == nil {
state.PEL = make(map[string][]*types.PendingEntry)
}
+34
View File
@@ -443,6 +443,40 @@ func (m *SegmentManager) ApplyRetention() error {
return nil
}
// RetentionOffsetBySize returns the offset to keep when enforcing size retention.
// It uses segment granularity (does not split segments).
func (m *SegmentManager) RetentionOffsetBySize(retentionBytes int64) (uint64, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if retentionBytes <= 0 {
return m.headOffset, nil
}
var totalSize int64
for _, seg := range m.segments {
totalSize += seg.Size()
}
if totalSize <= retentionBytes {
return m.headOffset, nil
}
sizeToTrim := totalSize - retentionBytes
var trimmed int64
for _, seg := range m.segments {
segSize := seg.Size()
if trimmed+segSize < sizeToTrim {
trimmed += segSize
continue
}
return seg.BaseOffset(), nil
}
return m.headOffset, nil
}
// applyTimeRetention removes segments older than the cutoff time.
func (m *SegmentManager) applyTimeRetention(cutoff time.Time) error {
// Keep at least one segment (the active one)
+20
View File
@@ -244,6 +244,26 @@ func (s *Store) ReadRange(queueName string, startOffset, endOffset uint64, maxMe
return manager.ReadRange(startOffset, endOffset, maxMessages)
}
// LookupByTime finds the offset for the given timestamp.
func (s *Store) LookupByTime(queueName string, ts time.Time) (uint64, error) {
manager, err := s.getQueue(queueName)
if err != nil {
return 0, err
}
return manager.LookupByTime(ts)
}
// RetentionOffsetBySize returns the offset to keep when enforcing size retention.
func (s *Store) RetentionOffsetBySize(queueName string, retentionBytes int64) (uint64, error) {
manager, err := s.getQueue(queueName)
if err != nil {
return 0, err
}
return manager.RetentionOffsetBySize(retentionBytes)
}
// Head returns the head offset for a queue.
func (s *Store) Head(queueName string) (uint64, error) {
manager, err := s.getQueue(queueName)
+109 -1
View File
@@ -21,6 +21,7 @@ var (
ErrConsumerNotFound = errors.New("consumer not found")
ErrMessageNotPending = errors.New("message not in pending list")
ErrInvalidOffset = errors.New("invalid offset")
ErrGroupModeMismatch = errors.New("consumer group mode mismatch")
)
// Manager handles consumer group operations including claiming,
@@ -68,10 +69,21 @@ func NewManager(queueStore storage.QueueStore, groupStore storage.ConsumerGroupS
}
// GetOrCreateGroup retrieves or creates a consumer group.
func (m *Manager) GetOrCreateGroup(ctx context.Context, queueName, groupID, pattern string) (*types.ConsumerGroupState, error) {
func (m *Manager) GetOrCreateGroup(ctx context.Context, queueName, groupID, pattern string, mode types.ConsumerGroupMode) (*types.ConsumerGroupState, error) {
// Try to get existing group
group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID)
if err == nil {
if mode == "" {
return group, nil
}
if group.Mode == "" {
group.Mode = mode
_ = m.groupStore.UpdateConsumerGroup(ctx, group)
return group, nil
}
if group.Mode != mode {
return nil, ErrGroupModeMismatch
}
return group, nil
}
@@ -82,6 +94,9 @@ func (m *Manager) GetOrCreateGroup(ctx context.Context, queueName, groupID, patt
// Create new group
group = types.NewConsumerGroupState(queueName, groupID, pattern)
if mode != "" {
group.Mode = mode
}
if err := m.groupStore.CreateConsumerGroup(ctx, group); err != nil {
// Handle race condition - another process might have created it
@@ -180,6 +195,70 @@ func (m *Manager) ClaimBatch(ctx context.Context, queueName, groupID, consumerID
return messages, nil
}
// ClaimBatchStream retrieves multiple messages for a stream consumer without PEL tracking.
// It advances the cursor once per batch for efficiency.
func (m *Manager) ClaimBatchStream(ctx context.Context, queueName, groupID, consumerID string, filter *Filter, limit int) ([]*types.Message, error) {
m.mu.Lock()
defer m.mu.Unlock()
if limit <= 0 {
limit = m.config.ClaimBatchSize
}
group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID)
if err != nil {
return nil, err
}
cursor := group.GetCursor()
tail, err := m.queueStore.Tail(ctx, group.QueueName)
if err != nil {
return nil, err
}
var messages []*types.Message
var newCursor uint64 = cursor.Cursor
for newCursor < tail && len(messages) < limit {
offset := newCursor
newCursor++
msg, err := m.queueStore.Read(ctx, group.QueueName, offset)
if err != nil {
if err == storage.ErrOffsetOutOfRange {
continue
}
return nil, err
}
if filter != nil {
queueRoot := "$queue/" + group.QueueName
routingKey := types.ExtractRoutingKey(msg.Topic, queueRoot)
if !filter.Matches(routingKey) {
continue
}
}
messages = append(messages, msg)
}
if len(messages) == 0 {
return nil, ErrNoMessages
}
if newCursor > cursor.Cursor {
if err := m.groupStore.UpdateCursor(ctx, group.QueueName, group.ID, newCursor); err != nil {
return nil, err
}
// Keep committed in sync for stream groups.
if err := m.groupStore.UpdateCommitted(ctx, group.QueueName, group.ID, newCursor); err != nil {
return nil, err
}
}
return messages, nil
}
// claimFromCursor tries to claim a message from the cursor position.
func (m *Manager) claimFromCursor(ctx context.Context, group *types.ConsumerGroupState, consumerID string, filter *Filter) (*types.Message, error) {
cursor := group.GetCursor()
@@ -465,6 +544,35 @@ func (m *Manager) GetMinCommittedOffset(ctx context.Context, queueName string) (
return minCommitted, nil
}
// GetMinCommittedOffsetByMode returns the minimum committed offset for groups matching mode.
// If no groups of that mode exist, returns the queue tail.
func (m *Manager) GetMinCommittedOffsetByMode(ctx context.Context, queueName string, mode types.ConsumerGroupMode) (uint64, error) {
groups, err := m.groupStore.ListConsumerGroups(ctx, queueName)
if err != nil {
return 0, err
}
var minCommitted uint64
first := true
for _, group := range groups {
if mode != "" && group.Mode != mode {
continue
}
cursor := group.GetCursor()
if first || cursor.Committed < minCommitted {
minCommitted = cursor.Committed
first = false
}
}
if first {
return m.queueStore.Tail(ctx, queueName)
}
return minCommitted, nil
}
// UpdateHeartbeat updates the heartbeat timestamp for a consumer.
func (m *Manager) UpdateHeartbeat(ctx context.Context, queueName, groupID, consumerID string) error {
m.mu.Lock()
+267 -17
View File
@@ -7,6 +7,8 @@ import (
"context"
"fmt"
"log/slog"
"strconv"
"strings"
"sync"
"time"
@@ -239,6 +241,11 @@ func (m *Manager) CreateQueue(ctx context.Context, config types.QueueConfig) err
return nil
}
// UpdateQueue updates an existing queue.
func (m *Manager) UpdateQueue(ctx context.Context, config types.QueueConfig) error {
return m.queueStore.UpdateQueue(ctx, config)
}
// GetOrCreateQueue gets or creates a queue with default configuration.
func (m *Manager) GetOrCreateQueue(ctx context.Context, queueName string, topics ...string) (*types.QueueConfig, error) {
// Try to get existing
@@ -421,19 +428,34 @@ func (m *Manager) Enqueue(ctx context.Context, topic string, payload []byte, pro
// SubscribeWithCursor adds a consumer with explicit cursor positioning.
func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern string, clientID, groupID, proxyNodeID string, cursor *types.CursorOption) error {
mode := types.GroupModeQueue
if cursor != nil && cursor.Mode != "" {
mode = cursor.Mode
}
if cursor == nil || cursor.Position == types.CursorDefault {
return m.Subscribe(ctx, queueName, pattern, clientID, groupID, proxyNodeID)
if mode != types.GroupModeStream {
return m.Subscribe(ctx, queueName, pattern, clientID, groupID, proxyNodeID)
}
cursor = &types.CursorOption{Position: types.CursorLatest, Mode: mode}
}
// Ensure queue exists
queueTopicPattern := "$queue/" + queueName + "/#"
_, err := m.GetOrCreateQueue(ctx, queueName, queueTopicPattern)
queueCfg, err := m.GetOrCreateQueue(ctx, queueName, queueTopicPattern)
if err != nil {
return fmt.Errorf("failed to get or create queue: %w", err)
}
if mode == types.GroupModeStream && queueCfg != nil && queueCfg.Type != types.QueueTypeStream {
queueCfg.Type = types.QueueTypeStream
_ = m.queueStore.UpdateQueue(ctx, *queueCfg)
}
if groupID == "" {
groupID = extractGroupFromClientID(clientID)
if mode == types.GroupModeStream {
groupID = clientID
} else {
groupID = extractGroupFromClientID(clientID)
}
}
patternGroupID := groupID
@@ -441,7 +463,7 @@ func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern st
patternGroupID = fmt.Sprintf("%s@%s", groupID, pattern)
}
group, err := m.consumerManager.GetOrCreateGroup(ctx, queueName, patternGroupID, pattern)
group, err := m.consumerManager.GetOrCreateGroup(ctx, queueName, patternGroupID, pattern, mode)
if err != nil {
return err
}
@@ -469,6 +491,12 @@ func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern st
offset = tail
}
m.groupStore.UpdateCursor(ctx, queueName, group.ID, offset)
case types.CursorTimestamp:
if !cursor.Timestamp.IsZero() {
if offset, err := m.offsetByTime(ctx, queueName, cursor.Timestamp); err == nil {
m.groupStore.UpdateCursor(ctx, queueName, group.ID, offset)
}
}
}
if err := m.consumerManager.RegisterConsumer(ctx, queueName, group.ID, clientID, clientID, proxyNodeID); err != nil {
@@ -485,6 +513,7 @@ func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern st
ConsumerID: clientID,
ClientID: clientID,
Pattern: pattern,
Mode: string(mode),
ProxyNodeID: proxyNodeID,
RegisteredAt: time.Now(),
}
@@ -501,7 +530,8 @@ func (m *Manager) SubscribeWithCursor(ctx context.Context, queueName, pattern st
slog.String("queue", queueName),
slog.String("group", patternGroupID),
slog.String("client", clientID),
slog.String("cursor", fmt.Sprintf("%d", cursor.Position)))
slog.String("cursor", fmt.Sprintf("%d", cursor.Position)),
slog.String("mode", string(mode)))
return nil
}
@@ -528,7 +558,7 @@ func (m *Manager) Subscribe(ctx context.Context, queueName, pattern string, clie
}
// Get or create consumer group
group, err := m.consumerManager.GetOrCreateGroup(ctx, queueName, patternGroupID, pattern)
group, err := m.consumerManager.GetOrCreateGroup(ctx, queueName, patternGroupID, pattern, types.GroupModeQueue)
if err != nil {
return err
}
@@ -549,6 +579,7 @@ func (m *Manager) Subscribe(ctx context.Context, queueName, pattern string, clie
ConsumerID: clientID,
ClientID: clientID,
Pattern: pattern,
Mode: string(types.GroupModeQueue),
ProxyNodeID: proxyNodeID,
RegisteredAt: time.Now(),
}
@@ -622,6 +653,20 @@ func (m *Manager) Ack(ctx context.Context, queueName, messageID, groupID string)
return err
}
if groupID != "" {
if group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID); err == nil {
if group.Mode == types.GroupModeStream {
cursor := group.GetCursor()
next := offset + 1
if next > cursor.Cursor {
_ = m.groupStore.UpdateCursor(ctx, queueName, group.ID, next)
_ = m.groupStore.UpdateCommitted(ctx, queueName, group.ID, next)
}
return nil
}
}
}
// Find the consumer that has this message pending
groups, err := m.groupStore.ListConsumerGroups(ctx, queueName)
if err != nil {
@@ -633,6 +678,15 @@ func (m *Manager) Ack(ctx context.Context, queueName, messageID, groupID string)
if groupID != "" && group.ID != groupID {
continue
}
if group.Mode == types.GroupModeStream {
cursor := group.GetCursor()
next := offset + 1
if next > cursor.Cursor {
_ = m.groupStore.UpdateCursor(ctx, queueName, group.ID, next)
_ = m.groupStore.UpdateCommitted(ctx, queueName, group.ID, next)
}
return nil
}
// Find and ack the message
for consumerID := range group.PEL {
@@ -654,6 +708,14 @@ func (m *Manager) Nack(ctx context.Context, queueName, messageID, groupID string
return err
}
if groupID != "" {
if group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID); err == nil {
if group.Mode == types.GroupModeStream {
return nil
}
}
}
groups, err := m.groupStore.ListConsumerGroups(ctx, queueName)
if err != nil {
return err
@@ -663,6 +725,9 @@ func (m *Manager) Nack(ctx context.Context, queueName, messageID, groupID string
if groupID != "" && group.ID != groupID {
continue
}
if group.Mode == types.GroupModeStream {
return nil
}
for consumerID := range group.PEL {
err := m.consumerManager.Nack(ctx, queueName, group.ID, consumerID, offset)
@@ -683,6 +748,14 @@ func (m *Manager) Reject(ctx context.Context, queueName, messageID, groupID, rea
return err
}
if groupID != "" {
if group, err := m.groupStore.GetConsumerGroup(ctx, queueName, groupID); err == nil {
if group.Mode == types.GroupModeStream {
return nil
}
}
}
groups, err := m.groupStore.ListConsumerGroups(ctx, queueName)
if err != nil {
return err
@@ -692,6 +765,9 @@ func (m *Manager) Reject(ctx context.Context, queueName, messageID, groupID, rea
if groupID != "" && group.ID != groupID {
continue
}
if group.Mode == types.GroupModeStream {
return nil
}
for consumerID := range group.PEL {
err := m.consumerManager.Reject(ctx, queueName, group.ID, consumerID, offset, reason)
@@ -761,6 +837,31 @@ func (m *Manager) deliverMessages() {
}
for _, queueConfig := range queues {
primaryGroup := strings.TrimSpace(queueConfig.PrimaryGroup)
primaryCommitted := make(map[string]uint64)
getPrimaryCommitted := func(pattern string) (uint64, bool) {
if primaryGroup == "" {
return 0, false
}
patternGroupID := primaryGroup
if pattern != "" {
patternGroupID = fmt.Sprintf("%s@%s", primaryGroup, pattern)
}
if val, ok := primaryCommitted[patternGroupID]; ok {
return val, true
}
committed, err := m.consumerManager.GetCommittedOffset(ctx, queueConfig.Name, patternGroupID)
if err != nil {
return 0, false
}
primaryCommitted[patternGroupID] = committed
return committed, true
}
// Deliver to local consumer groups
groups, err := m.groupStore.ListConsumerGroups(ctx, queueConfig.Name)
if err != nil {
@@ -768,7 +869,7 @@ func (m *Manager) deliverMessages() {
}
for _, group := range groups {
m.deliverToGroup(ctx, &queueConfig, group)
m.deliverToGroup(ctx, &queueConfig, group, getPrimaryCommitted)
}
// Also deliver to remote consumers registered in cluster
@@ -801,8 +902,12 @@ func (m *Manager) deliverToRemoteConsumers(ctx context.Context, config *types.Qu
}
for groupID, groupConsumers := range consumersByGroup {
mode := types.GroupModeQueue
if groupConsumers[0].Mode != "" {
mode = types.ConsumerGroupMode(groupConsumers[0].Mode)
}
// Get or create a local consumer group state for tracking cursor
group, err := m.consumerManager.GetOrCreateGroup(ctx, config.Name, groupID, groupConsumers[0].Pattern)
group, err := m.consumerManager.GetOrCreateGroup(ctx, config.Name, groupID, groupConsumers[0].Pattern, mode)
if err != nil {
continue
}
@@ -814,23 +919,62 @@ func (m *Manager) deliverToRemoteConsumers(ctx context.Context, config *types.Qu
}
// Round-robin across remote consumers in this group
var workCommitted uint64
var hasWorkCommitted bool
if group.Mode == types.GroupModeStream && config.PrimaryGroup != "" {
patternGroupID := config.PrimaryGroup
if group.Pattern != "" {
patternGroupID = fmt.Sprintf("%s@%s", config.PrimaryGroup, group.Pattern)
}
if committed, err := m.consumerManager.GetCommittedOffset(ctx, config.Name, patternGroupID); err == nil {
workCommitted = committed
hasWorkCommitted = true
}
}
for _, consumerInfo := range groupConsumers {
// Claim messages for this remote consumer
msgs, err := m.consumerManager.ClaimBatch(ctx, config.Name, groupID, consumerInfo.ConsumerID, filter, m.config.DeliveryBatchSize)
var msgs []*types.Message
var err error
if group.Mode == types.GroupModeStream {
msgs, err = m.consumerManager.ClaimBatchStream(ctx, config.Name, groupID, consumerInfo.ConsumerID, filter, m.config.DeliveryBatchSize)
} else {
msgs, err = m.consumerManager.ClaimBatch(ctx, config.Name, groupID, consumerInfo.ConsumerID, filter, m.config.DeliveryBatchSize)
}
if err != nil {
continue
}
// Route each message to the remote node
for _, msg := range msgs {
payload := msg.GetPayload()
properties := msg.Properties
if group.Mode == types.GroupModeStream {
propsCopy := make(map[string]string, len(msg.Properties)+4)
for k, v := range msg.Properties {
propsCopy[k] = v
}
// Decorate properties for stream consumers.
propsCopy["x-stream-offset"] = fmt.Sprintf("%d", msg.Sequence)
if !msg.CreatedAt.IsZero() {
propsCopy["x-stream-timestamp"] = fmt.Sprintf("%d", msg.CreatedAt.UnixMilli())
}
if hasWorkCommitted {
propsCopy["x-work-committed-offset"] = fmt.Sprintf("%d", workCommitted)
propsCopy["x-work-acked"] = strconv.FormatBool(msg.Sequence < workCommitted)
propsCopy["x-work-group"] = config.PrimaryGroup
}
properties = propsCopy
}
err := m.cluster.RouteQueueMessage(
ctx,
consumerInfo.ProxyNodeID,
consumerInfo.ClientID,
config.Name,
msg.ID,
msg.GetPayload(),
msg.Properties,
payload,
properties,
int64(msg.Sequence),
)
if err != nil {
@@ -851,7 +995,7 @@ func (m *Manager) deliverToRemoteConsumers(ctx context.Context, config *types.Qu
}
}
func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig, group *types.ConsumerGroupState) {
func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig, group *types.ConsumerGroupState, primaryCommitted func(pattern string) (uint64, bool)) {
if group.ConsumerCount() == 0 {
return
}
@@ -869,7 +1013,13 @@ func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig,
}
for _, consumerID := range consumers {
msgs, err := m.consumerManager.ClaimBatch(ctx, config.Name, group.ID, consumerID, filter, m.config.DeliveryBatchSize)
var msgs []*types.Message
var err error
if group.Mode == types.GroupModeStream {
msgs, err = m.consumerManager.ClaimBatchStream(ctx, config.Name, group.ID, consumerID, filter, m.config.DeliveryBatchSize)
} else {
msgs, err = m.consumerManager.ClaimBatch(ctx, config.Name, group.ID, consumerID, filter, m.config.DeliveryBatchSize)
}
if err != nil {
continue
}
@@ -890,6 +1040,12 @@ func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig,
m.consumerManager.UpdateHeartbeat(ctx, config.Name, group.ID, consumerID)
}
var workCommitted uint64
var hasWorkCommitted bool
if group.Mode == types.GroupModeStream && primaryCommitted != nil {
workCommitted, hasWorkCommitted = primaryCommitted(group.Pattern)
}
for _, msg := range msgs {
// Check if consumer is on a remote node
if m.cluster != nil && consumerInfo.ProxyNodeID != "" && consumerInfo.ProxyNodeID != m.localNodeID {
@@ -914,6 +1070,9 @@ func (m *Manager) deliverToGroup(ctx context.Context, config *types.QueueConfig,
} else if m.deliverFn != nil {
// Local delivery
deliveryMsg := m.createDeliveryMessage(msg, group.ID, config.Name)
if group.Mode == types.GroupModeStream {
m.decorateStreamDelivery(deliveryMsg, msg, group, workCommitted, hasWorkCommitted, config.PrimaryGroup)
}
if err := m.deliverFn(ctx, consumerInfo.ClientID, deliveryMsg); err != nil {
m.logger.Warn("queue message delivery failed",
@@ -1009,14 +1168,21 @@ func (m *Manager) processRetention() {
}
for _, queueConfig := range queues {
// Get minimum committed offset across all groups
minCommitted, err := m.consumerManager.GetMinCommittedOffset(ctx, queueConfig.Name)
// Get minimum committed offset across queue-mode groups
minCommitted, err := m.consumerManager.GetMinCommittedOffsetByMode(ctx, queueConfig.Name, types.GroupModeQueue)
if err != nil {
continue
}
// Truncate log up to committed offset
if err := m.queueStore.Truncate(ctx, queueConfig.Name, minCommitted); err != nil {
truncateOffset := minCommitted
if retentionOffset, hasRetention := m.computeRetentionOffset(ctx, &queueConfig); hasRetention {
if retentionOffset < truncateOffset {
truncateOffset = retentionOffset
}
}
// Truncate log up to the safe offset
if err := m.queueStore.Truncate(ctx, queueConfig.Name, truncateOffset); err != nil {
m.logger.Debug("truncation error",
slog.String("error", err.Error()),
slog.String("queue", queueConfig.Name))
@@ -1193,6 +1359,28 @@ func (m *Manager) createDeliveryMessage(msg *types.Message, groupID string, queu
return deliveryMsg
}
func (m *Manager) decorateStreamDelivery(delivery *brokerstorage.Message, msg *types.Message, group *types.ConsumerGroupState, workCommitted uint64, hasWorkCommitted bool, primaryGroup string) {
if delivery == nil || msg == nil {
return
}
if delivery.Properties == nil {
delivery.Properties = make(map[string]string)
}
delivery.Properties["x-stream-offset"] = fmt.Sprintf("%d", msg.Sequence)
if !msg.CreatedAt.IsZero() {
delivery.Properties["x-stream-timestamp"] = fmt.Sprintf("%d", msg.CreatedAt.UnixMilli())
}
if hasWorkCommitted {
delivery.Properties["x-work-committed-offset"] = fmt.Sprintf("%d", workCommitted)
delivery.Properties["x-work-acked"] = strconv.FormatBool(msg.Sequence < workCommitted)
if primaryGroup != "" {
delivery.Properties["x-work-group"] = primaryGroup
}
}
}
// DeliveryMessage is the internal message format for queue delivery tracking.
type DeliveryMessage struct {
ID string
@@ -1235,6 +1423,68 @@ func parseMessageID(messageID string) (uint64, error) {
return offset, err
}
func (m *Manager) offsetByTime(ctx context.Context, queueName string, ts time.Time) (uint64, error) {
if provider, ok := m.queueStore.(storage.TimeOffsetProvider); ok {
return provider.OffsetByTime(ctx, queueName, ts)
}
return m.queueStore.Head(ctx, queueName)
}
func (m *Manager) offsetBySize(ctx context.Context, queueName string, retentionBytes int64) (uint64, error) {
if provider, ok := m.queueStore.(storage.SizeOffsetProvider); ok {
return provider.OffsetBySize(ctx, queueName, retentionBytes)
}
return m.queueStore.Head(ctx, queueName)
}
func (m *Manager) computeRetentionOffset(ctx context.Context, config *types.QueueConfig) (uint64, bool) {
if config == nil {
return 0, false
}
var offset uint64
hasRetention := false
if config.Retention.RetentionTime > 0 {
cutoff := time.Now().Add(-config.Retention.RetentionTime)
if off, err := m.offsetByTime(ctx, config.Name, cutoff); err == nil {
if off > offset {
offset = off
}
hasRetention = true
}
}
if config.Retention.RetentionBytes > 0 {
if off, err := m.offsetBySize(ctx, config.Name, config.Retention.RetentionBytes); err == nil {
if off > offset {
offset = off
}
hasRetention = true
}
}
if config.Retention.RetentionMessages > 0 {
head, err := m.queueStore.Head(ctx, config.Name)
if err == nil {
tail, err := m.queueStore.Tail(ctx, config.Name)
if err == nil {
if tail > head+uint64(config.Retention.RetentionMessages) {
msgOffset := tail - uint64(config.Retention.RetentionMessages)
if msgOffset > offset {
offset = msgOffset
}
} else if head > offset {
offset = head
}
hasRetention = true
}
}
}
return offset, hasRetention
}
func parseSubscriptionFilter(filter string) (queueName, pattern string) {
for i, c := range filter {
if c == '/' {
+114
View File
@@ -5,6 +5,7 @@ package queue
import (
"context"
"io"
"log/slog"
"sync"
"testing"
@@ -329,6 +330,119 @@ func TestWildcardQueueSubscription(t *testing.T) {
}
}
func TestStreamGroupDeliversWithoutPEL(t *testing.T) {
logStore := memlog.New()
groupStore := newMockGroupStore()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
delivered := make(chan *brokerstorage.Message, 1)
deliverFn := func(ctx context.Context, clientID string, msg any) error {
if m, ok := msg.(*brokerstorage.Message); ok {
delivered <- m
}
return nil
}
cfg := DefaultConfig()
cfg.DeliveryBatchSize = 10
mgr := NewManager(logStore, groupStore, deliverFn, cfg, logger, nil)
queueCfg := types.DefaultQueueConfig("events", "$queue/events/#")
queueCfg.Type = types.QueueTypeStream
if err := mgr.CreateQueue(context.Background(), queueCfg); err != nil {
t.Fatalf("CreateQueue failed: %v", err)
}
cursor := &types.CursorOption{Position: types.CursorEarliest, Mode: types.GroupModeStream}
if err := mgr.SubscribeWithCursor(context.Background(), "events", "", "client-1", "streamer", "", cursor); err != nil {
t.Fatalf("SubscribeWithCursor failed: %v", err)
}
if err := mgr.Publish(context.Background(), "$queue/events/test", []byte("hello"), nil); err != nil {
t.Fatalf("Publish failed: %v", err)
}
mgr.deliverMessages()
select {
case msg := <-delivered:
if got := msg.Properties["x-stream-offset"]; got != "0" {
t.Fatalf("expected stream offset 0, got %q", got)
}
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for delivery")
}
group, err := groupStore.GetConsumerGroup(context.Background(), "events", "streamer")
if err != nil {
t.Fatalf("GetConsumerGroup failed: %v", err)
}
if count := group.PendingCount(); count != 0 {
t.Fatalf("expected no pending entries, got %d", count)
}
if cursor := group.GetCursor().Cursor; cursor != 1 {
t.Fatalf("expected cursor 1, got %d", cursor)
}
}
func TestStreamAckAdvancesCursor(t *testing.T) {
logStore := memlog.New()
groupStore := newMockGroupStore()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
mgr := NewManager(logStore, groupStore, nil, DefaultConfig(), logger, nil)
queueCfg := types.DefaultQueueConfig("events", "$queue/events/#")
queueCfg.Type = types.QueueTypeStream
if err := mgr.CreateQueue(context.Background(), queueCfg); err != nil {
t.Fatalf("CreateQueue failed: %v", err)
}
cursor := &types.CursorOption{Position: types.CursorEarliest, Mode: types.GroupModeStream}
if err := mgr.SubscribeWithCursor(context.Background(), "events", "", "client-1", "streamer", "", cursor); err != nil {
t.Fatalf("SubscribeWithCursor failed: %v", err)
}
if err := mgr.Ack(context.Background(), "events", "events:0", "streamer"); err != nil {
t.Fatalf("Ack failed: %v", err)
}
group, err := groupStore.GetConsumerGroup(context.Background(), "events", "streamer")
if err != nil {
t.Fatalf("GetConsumerGroup failed: %v", err)
}
if cursor := group.GetCursor().Cursor; cursor != 1 {
t.Fatalf("expected cursor 1, got %d", cursor)
}
}
func TestRetentionOffsetMessages(t *testing.T) {
logStore := memlog.New()
groupStore := newMockGroupStore()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
mgr := NewManager(logStore, groupStore, nil, DefaultConfig(), logger, nil)
queueCfg := types.DefaultQueueConfig("events", "$queue/events/#")
if err := mgr.CreateQueue(context.Background(), queueCfg); err != nil {
t.Fatalf("CreateQueue failed: %v", err)
}
for i := 0; i < 3; i++ {
if err := mgr.Publish(context.Background(), "$queue/events/test", []byte("msg"), nil); err != nil {
t.Fatalf("Publish failed: %v", err)
}
}
queueCfg.Retention.RetentionMessages = 1
queueCfg.Name = "events"
offset, ok := mgr.computeRetentionOffset(context.Background(), &queueCfg)
if !ok {
t.Fatal("expected retention offset")
}
if offset != 2 {
t.Fatalf("expected retention offset 2, got %d", offset)
}
}
func TestExactQueueSubscription(t *testing.T) {
logStore := memlog.New()
groupStore := newMockGroupStore()
+48
View File
@@ -279,6 +279,54 @@ func (s *Store) Tail(ctx context.Context, queueName string) (uint64, error) {
return sl.tail, nil
}
// OffsetByTime returns the offset for the given timestamp.
func (s *Store) OffsetByTime(ctx context.Context, queueName string, ts time.Time) (uint64, error) {
sl, err := s.getQueueLog(queueName)
if err != nil {
return 0, err
}
sl.mu.RLock()
defer sl.mu.RUnlock()
if len(sl.messages) == 0 {
return sl.head, nil
}
for i, msg := range sl.messages {
if !msg.CreatedAt.IsZero() && !msg.CreatedAt.Before(ts) {
return sl.head + uint64(i), nil
}
}
return sl.tail, nil
}
// OffsetBySize returns the offset to keep when enforcing size retention.
func (s *Store) OffsetBySize(ctx context.Context, queueName string, retentionBytes int64) (uint64, error) {
sl, err := s.getQueueLog(queueName)
if err != nil {
return 0, err
}
sl.mu.RLock()
defer sl.mu.RUnlock()
if retentionBytes <= 0 || len(sl.messages) == 0 {
return sl.head, nil
}
var total int64
for i := len(sl.messages) - 1; i >= 0; i-- {
total += int64(len(sl.messages[i].GetPayload()))
if total > retentionBytes {
return sl.head + uint64(i+1), nil
}
}
return sl.head, nil
}
// Truncate removes all messages with offset < minOffset.
func (s *Store) Truncate(ctx context.Context, queueName string, minOffset uint64) error {
sl, err := s.getQueueLog(queueName)
+19
View File
@@ -0,0 +1,19 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package storage
import (
"context"
"time"
)
// TimeOffsetProvider exposes time-based offset lookups.
type TimeOffsetProvider interface {
OffsetByTime(ctx context.Context, queueName string, ts time.Time) (uint64, error)
}
// SizeOffsetProvider exposes size-based retention offsets.
type SizeOffsetProvider interface {
OffsetBySize(ctx context.Context, queueName string, retentionBytes int64) (uint64, error)
}
+30 -1
View File
@@ -15,16 +15,21 @@ var ErrInvalidConfig = errors.New("invalid queue configuration")
type CursorPosition int
const (
CursorDefault CursorPosition = iota // resume from stored position
CursorDefault CursorPosition = iota // resume from stored position
CursorEarliest // start from beginning
CursorLatest // start from end
CursorOffset // start from specific offset
CursorTimestamp // start from a timestamp
)
// CursorOption specifies cursor positioning for SubscribeWithCursor.
type CursorOption struct {
Position CursorPosition
Offset uint64 // only used when Position == CursorOffset
// Timestamp is used when Position == CursorTimestamp.
Timestamp time.Time
// Mode defines the consumer group mode (queue or stream).
Mode ConsumerGroupMode
}
// Queue constants
@@ -46,6 +51,11 @@ type QueueConfig struct {
Name string
Topics []string // Topic patterns that route to this queue (e.g., "sensors/#", "orders/+/created")
Reserved bool // True for system queues like "mqtt" that cannot be deleted
Type QueueType
// PrimaryGroup defines the consumer group whose committed offset is used
// to report delivery status to stream consumers.
PrimaryGroup string
// Durability
Durable bool // true = persists indefinitely, false = ephemeral (cleaned up when no consumers remain)
@@ -68,6 +78,14 @@ type QueueConfig struct {
HeartbeatTimeout time.Duration
}
// QueueType defines the queue behavior mode.
type QueueType string
const (
QueueTypeClassic QueueType = "classic"
QueueTypeStream QueueType = "stream"
)
// RetryPolicy defines retry behavior for failed messages.
type RetryPolicy struct {
MaxRetries int
@@ -134,6 +152,7 @@ func DefaultQueueConfig(name string, topics ...string) QueueConfig {
Name: name,
Topics: topics,
Reserved: false,
Type: QueueTypeClassic,
Durable: true,
MaxMessageSize: 10 * 1024 * 1024, // 10MB
MaxDepth: 100000,
@@ -178,6 +197,8 @@ type QueueConfigInput struct {
Name string
Topics []string
Reserved bool
Type QueueType
PrimaryGroup string
MaxMessageSize int64
MaxDepth int64
MessageTTL time.Duration
@@ -187,6 +208,7 @@ type QueueConfigInput struct {
Multiplier float64
DLQEnabled bool
DLQTopic string
Retention RetentionPolicy
}
// FromInput creates a QueueConfig from a simplified input config.
@@ -194,6 +216,11 @@ func FromInput(input QueueConfigInput) QueueConfig {
cfg := DefaultQueueConfig(input.Name, input.Topics...)
cfg.Durable = true
cfg.Reserved = input.Reserved
if input.Type != "" {
cfg.Type = input.Type
}
cfg.PrimaryGroup = input.PrimaryGroup
cfg.Retention = input.Retention
if input.MaxMessageSize > 0 {
cfg.MaxMessageSize = input.MaxMessageSize
@@ -254,6 +281,8 @@ func (c *QueueConfig) Validate() error {
return ErrInvalidConfig
case c.Reserved && !c.Durable:
return ErrInvalidConfig
case c.Type != "" && c.Type != QueueTypeClassic && c.Type != QueueTypeStream:
return ErrInvalidConfig
}
// Validate replication config if enabled
+10
View File
@@ -43,6 +43,7 @@ type ConsumerGroupState struct {
ID string // Group identifier
QueueName string // Queue this group consumes from
Pattern string // Subscription pattern (e.g., "sensors/#")
Mode ConsumerGroupMode
// Queue cursor state (single cursor per queue, no partitions)
Cursor *QueueCursor
@@ -59,6 +60,14 @@ type ConsumerGroupState struct {
UpdatedAt time.Time
}
// ConsumerGroupMode defines how a consumer group is tracked.
type ConsumerGroupMode string
const (
GroupModeQueue ConsumerGroupMode = "queue"
GroupModeStream ConsumerGroupMode = "stream"
)
// NewConsumerGroupState creates a new consumer group state.
func NewConsumerGroupState(queueName, groupID, pattern string) *ConsumerGroupState {
now := time.Now()
@@ -66,6 +75,7 @@ func NewConsumerGroupState(queueName, groupID, pattern string) *ConsumerGroupSta
ID: groupID,
QueueName: queueName,
Pattern: pattern,
Mode: GroupModeQueue,
Cursor: &QueueCursor{
Cursor: 0,
Committed: 0,
+6
View File
@@ -145,6 +145,12 @@ func (s *Server) handlePublish(w mux.ResponseWriter, r *mux.Message) {
return
}
// if !strings.HasPrefix(path, "/mqtt/publish/") {
// s.logger.Warn("coap_publish_invalid_path", slog.String("path", path))
// s.sendResponse(w, r, codes.BadRequest, "invalid path")
// return
// }
topic := strings.TrimPrefix(path, "/mqtt/publish/")
if topic == "" {
s.logger.Warn("coap_publish_missing_topic")
+173 -50
View File
@@ -4,22 +4,124 @@
package coap
import (
"bytes"
"context"
"io"
"log/slog"
"net"
"testing"
"time"
"github.com/absmach/fluxmq/broker"
"github.com/absmach/fluxmq/cluster"
"github.com/absmach/fluxmq/config"
"github.com/absmach/fluxmq/mqtt/broker"
"github.com/absmach/fluxmq/storage/memory"
piondtls "github.com/pion/dtls/v3"
"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/message/pool"
"github.com/plgd-dev/go-coap/v3/mux"
)
type stubConn struct {
last *pool.Message
done chan struct{}
}
func newStubConn() *stubConn {
return &stubConn{done: make(chan struct{})}
}
func (c *stubConn) AcquireMessage(ctx context.Context) *pool.Message {
return pool.NewMessage(ctx)
}
func (c *stubConn) ReleaseMessage(*pool.Message) {}
func (c *stubConn) Ping(ctx context.Context) error { return nil }
func (c *stubConn) Get(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
return nil, nil
}
func (c *stubConn) Delete(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
return nil, nil
}
func (c *stubConn) Post(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) {
return nil, nil
}
func (c *stubConn) Put(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) {
return nil, nil
}
func (c *stubConn) Observe(ctx context.Context, path string, observeFunc func(notification *pool.Message), opts ...message.Option) (mux.Observation, error) {
return nil, nil
}
func (c *stubConn) RemoteAddr() net.Addr { return stubAddr("coap-stub") }
func (c *stubConn) NetConn() net.Conn { return nil }
func (c *stubConn) Context() context.Context {
return context.Background()
}
func (c *stubConn) SetContextValue(key interface{}, val interface{}) {}
func (c *stubConn) WriteMessage(req *pool.Message) error {
c.last = req
return nil
}
func (c *stubConn) Do(req *pool.Message) (*pool.Message, error) { return nil, nil }
func (c *stubConn) DoObserve(req *pool.Message, observeFunc func(req *pool.Message)) (mux.Observation, error) {
return nil, nil
}
func (c *stubConn) Close() error {
close(c.done)
return nil
}
func (c *stubConn) Sequence() uint64 { return 0 }
func (c *stubConn) Done() <-chan struct{} {
return c.done
}
func (c *stubConn) AddOnClose(func()) {}
func (c *stubConn) NewGetRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
return nil, nil
}
func (c *stubConn) NewObserveRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
return nil, nil
}
func (c *stubConn) NewPutRequest(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) {
return nil, nil
}
func (c *stubConn) NewPostRequest(ctx context.Context, path string, contentFormat message.MediaType, payload io.ReadSeeker, opts ...message.Option) (*pool.Message, error) {
return nil, nil
}
func (c *stubConn) NewDeleteRequest(ctx context.Context, path string, opts ...message.Option) (*pool.Message, error) {
return nil, nil
}
type stubResponseWriter struct {
conn *stubConn
msg *pool.Message
}
func (w *stubResponseWriter) SetResponse(code codes.Code, contentFormat message.MediaType, d io.ReadSeeker, opts ...message.Option) error {
resp := w.conn.AcquireMessage(context.Background())
resp.SetCode(code)
resp.SetBody(d)
w.conn.last = resp
return nil
}
func (w *stubResponseWriter) Conn() mux.Conn { return w.conn }
func (w *stubResponseWriter) SetMessage(m *pool.Message) {
w.msg = m
}
func (w *stubResponseWriter) Message() *pool.Message { return w.msg }
type stubAddr string
func (a stubAddr) Network() string { return "stub" }
func (a stubAddr) String() string { return string(a) }
func TestNew(t *testing.T) {
store := memory.New()
cl := cluster.NewNoopCluster("test-node")
stats := broker.NewStats()
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil)
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
@@ -45,8 +147,7 @@ func TestNew(t *testing.T) {
func TestConfig_TLSConfig(t *testing.T) {
cfg := Config{
Address: ":5684",
ShutdownTimeout: 5 * time.Second,
Address: ":5684",
TLSConfig: &piondtls.Config{
ClientAuth: piondtls.RequireAndVerifyClientCert,
},
@@ -60,63 +161,85 @@ func TestConfig_TLSConfig(t *testing.T) {
}
}
func TestServer_ListenUDP_ContextCancel(t *testing.T) {
func TestHandleHealth(t *testing.T) {
store := memory.New()
cl := cluster.NewNoopCluster("test-node")
stats := broker.NewStats()
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil)
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
Address: ":15683", // Use non-standard port for testing
ShutdownTimeout: 1 * time.Second,
server := New(Config{}, b, slog.Default())
conn := newStubConn()
writer := &stubResponseWriter{conn: conn}
req := &mux.Message{Message: pool.NewMessage(context.Background())}
server.handleHealth(writer, req)
if conn.last == nil {
t.Fatal("expected response message")
}
if conn.last.Code() != codes.Content {
t.Fatalf("expected code %v, got %v", codes.Content, conn.last.Code())
}
body, err := conn.last.ReadBody()
if err != nil {
t.Fatalf("failed to read body: %v", err)
}
if string(body) != "healthy" {
t.Fatalf("expected body %q, got %q", "healthy", string(body))
}
}
server := New(cfg, b, slog.Default())
func TestHandlePublish(t *testing.T) {
store := memory.New()
cl := cluster.NewNoopCluster("test-node")
stats := broker.NewStats()
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil, config.SessionConfig{})
defer b.Close()
ctx, cancel := context.WithCancel(context.Background())
server := New(Config{}, b, slog.Default())
done := make(chan error, 1)
go func() {
done <- server.Listen(ctx)
}()
t.Run("missing topic", func(t *testing.T) {
conn := newStubConn()
writer := &stubResponseWriter{conn: conn}
// Give server time to start
time.Sleep(100 * time.Millisecond)
reqMsg := pool.NewMessage(context.Background())
reqMsg.MustSetPath("/mqtt/publish")
req := &mux.Message{Message: reqMsg}
// Cancel context to trigger shutdown
cancel()
server.handlePublish(writer, req)
select {
case err := <-done:
if err != nil {
t.Errorf("Expected nil error on shutdown, got: %v", err)
if conn.last == nil {
t.Fatal("expected response message")
}
case <-time.After(5 * time.Second):
t.Error("Server did not shut down in time")
}
}
func TestServer_ListenDTLS_InvalidConfig(t *testing.T) {
store := memory.New()
cl := cluster.NewNoopCluster("test-node")
stats := broker.NewStats()
b := broker.NewBroker(store, cl, slog.Default(), stats, nil, nil, nil)
defer b.Close()
cfg := Config{
Address: ":15684",
ShutdownTimeout: 1 * time.Second,
TLSConfig: &piondtls.Config{},
}
server := New(cfg, b, slog.Default())
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := server.Listen(ctx)
if err == nil {
t.Error("Expected error for invalid DTLS configuration")
}
if conn.last.Code() != codes.BadRequest {
t.Fatalf("expected code %v, got %v", codes.BadRequest, conn.last.Code())
}
})
t.Run("ok", func(t *testing.T) {
conn := newStubConn()
writer := &stubResponseWriter{conn: conn}
reqMsg := pool.NewMessage(context.Background())
reqMsg.MustSetPath("/mqtt/publish/test/topic")
reqMsg.SetBody(bytes.NewReader([]byte("payload")))
req := &mux.Message{Message: reqMsg}
server.handlePublish(writer, req)
if conn.last == nil {
t.Fatal("expected response message")
}
if conn.last.Code() != codes.Changed {
t.Fatalf("expected code %v, got %v", codes.Changed, conn.last.Code())
}
body, err := conn.last.ReadBody()
if err != nil {
t.Fatalf("failed to read body: %v", err)
}
if string(body) != "ok" {
t.Fatalf("expected body %q, got %q", "ok", string(body))
}
})
}
+84 -211
View File
@@ -9,11 +9,12 @@ import (
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/absmach/fluxmq/broker"
"github.com/absmach/fluxmq/cluster"
"github.com/absmach/fluxmq/config"
"github.com/absmach/fluxmq/mqtt/broker"
clusterv1 "github.com/absmach/fluxmq/pkg/proto/cluster/v1"
"github.com/absmach/fluxmq/storage"
)
@@ -36,22 +37,6 @@ func (m *mockCluster) AcquireSession(ctx context.Context, clientID, nodeID strin
return nil
}
func (m *mockCluster) AcquirePartition(ctx context.Context, queueName string, partitionID int, nodeID string) error {
return nil
}
func (m *mockCluster) ReleasePartition(ctx context.Context, queueName string, partitionID int) error {
return nil
}
func (m *mockCluster) EnqueueRemote(ctx context.Context, nodeID, queueName string, payload []byte, properties map[string]string) (string, error) {
return "", nil
}
func (m *mockCluster) RouteQueueMessage(ctx context.Context, nodeID, clientID, queueName, messageID string, payload []byte, properties map[string]string, sequence int64, partitionID int) error {
return nil
}
func (m *mockCluster) ReleaseSession(ctx context.Context, clientID string) error {
return nil
}
@@ -60,6 +45,10 @@ func (m *mockCluster) GetSessionOwner(ctx context.Context, clientID string) (str
return "", false, nil
}
func (m *mockCluster) WatchSessionOwner(ctx context.Context, clientID string) <-chan cluster.OwnershipChange {
return nil
}
func (m *mockCluster) AddSubscription(ctx context.Context, clientID, filter string, qos byte, opts storage.SubscribeOptions) error {
return nil
}
@@ -76,6 +65,30 @@ func (m *mockCluster) GetSubscribersForTopic(ctx context.Context, topic string)
return nil, nil
}
func (m *mockCluster) RegisterQueueConsumer(ctx context.Context, info *cluster.QueueConsumerInfo) error {
return nil
}
func (m *mockCluster) UnregisterQueueConsumer(ctx context.Context, queueName, groupID, consumerID string) error {
return nil
}
func (m *mockCluster) ListQueueConsumers(ctx context.Context, queueName string) ([]*cluster.QueueConsumerInfo, error) {
return nil, nil
}
func (m *mockCluster) ListQueueConsumersByGroup(ctx context.Context, queueName, groupID string) ([]*cluster.QueueConsumerInfo, error) {
return nil, nil
}
func (m *mockCluster) ListAllQueueConsumers(ctx context.Context) ([]*cluster.QueueConsumerInfo, error) {
return nil, nil
}
func (m *mockCluster) ForwardQueuePublish(ctx context.Context, nodeID, topic string, payload []byte, properties map[string]string) error {
return nil
}
func (m *mockCluster) Retained() storage.RetainedStore {
return nil
}
@@ -92,6 +105,10 @@ func (m *mockCluster) TakeoverSession(ctx context.Context, clientID, fromNode, t
return nil, nil
}
func (m *mockCluster) RouteQueueMessage(ctx context.Context, nodeID, clientID, queueName, messageID string, payload []byte, properties map[string]string, sequence int64) error {
return nil
}
func (m *mockCluster) WaitForLeader(ctx context.Context) error {
return nil
}
@@ -108,57 +125,21 @@ func (m *mockCluster) Nodes() []cluster.NodeInfo {
return nil
}
func TestServerStartStop(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
func TestAddrWithoutListener(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
Address: "localhost:0",
ShutdownTimeout: 1 * time.Second,
}
server := New(cfg, b, nil, slog.Default())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- server.Listen(ctx)
}()
time.Sleep(100 * time.Millisecond)
cancel()
select {
case err := <-errCh:
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("server did not stop in time")
server := New(Config{}, b, nil, slog.Default())
if server.Addr() != "" {
t.Fatalf("expected empty address before listen, got %q", server.Addr())
}
}
func TestHealthEndpoint(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
Address: "localhost:0",
ShutdownTimeout: 1 * time.Second,
}
server := New(cfg, b, nil, slog.Default())
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.Listen(ctx)
time.Sleep(100 * time.Millisecond)
addr := server.Addr()
server := New(Config{}, b, nil, slog.Default())
tests := []struct {
name string
@@ -186,25 +167,18 @@ func TestHealthEndpoint(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := http.NewRequest(tt.method, "http://"+addr+"/health", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
req := httptest.NewRequest(tt.method, "http://test/health", nil)
rec := httptest.NewRecorder()
client := &http.Client{Timeout: 2 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to send request: %v", err)
}
defer resp.Body.Close()
server.handleHealth(rec, req)
if resp.StatusCode != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
if rec.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, rec.Code)
}
if tt.expectedStatus == http.StatusOK {
var response HealthResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
@@ -237,7 +211,7 @@ func TestReadyEndpoint(t *testing.T) {
},
{
name: "single node mode - ready",
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil),
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{}),
cluster: nil,
method: http.MethodGet,
expectedStatus: http.StatusOK,
@@ -245,7 +219,7 @@ func TestReadyEndpoint(t *testing.T) {
},
{
name: "cluster not initialized - not ready",
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil),
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{}),
cluster: &mockCluster{nodeID: ""},
method: http.MethodGet,
expectedStatus: http.StatusServiceUnavailable,
@@ -254,7 +228,7 @@ func TestReadyEndpoint(t *testing.T) {
},
{
name: "cluster initialized - ready",
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil),
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{}),
cluster: &mockCluster{nodeID: "node-1", isLeader: true},
method: http.MethodGet,
expectedStatus: http.StatusOK,
@@ -262,7 +236,7 @@ func TestReadyEndpoint(t *testing.T) {
},
{
name: "POST request not allowed",
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil),
broker: broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{}),
cluster: nil,
method: http.MethodPost,
expectedStatus: http.StatusMethodNotAllowed,
@@ -275,40 +249,20 @@ func TestReadyEndpoint(t *testing.T) {
defer tt.broker.Close()
}
cfg := Config{
Address: "localhost:0",
ShutdownTimeout: 1 * time.Second,
}
server := New(Config{}, tt.broker, tt.cluster, slog.Default())
server := New(cfg, tt.broker, tt.cluster, slog.Default())
req := httptest.NewRequest(tt.method, "http://test/ready", nil)
rec := httptest.NewRecorder()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server.handleReady(rec, req)
go server.Listen(ctx)
time.Sleep(100 * time.Millisecond)
addr := server.Addr()
req, err := http.NewRequest(tt.method, "http://"+addr+"/ready", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
client := &http.Client{Timeout: 2 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
if rec.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, rec.Code)
}
if tt.expectedStatus == http.StatusOK || tt.expectedStatus == http.StatusServiceUnavailable {
var response ReadyResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
@@ -376,38 +330,18 @@ func TestClusterStatusEndpoint(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
Address: "localhost:0",
ShutdownTimeout: 1 * time.Second,
}
server := New(Config{}, b, tt.cluster, slog.Default())
server := New(cfg, b, tt.cluster, slog.Default())
req := httptest.NewRequest(tt.method, "http://test/cluster/status", nil)
rec := httptest.NewRecorder()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
server.handleClusterStatus(rec, req)
go server.Listen(ctx)
time.Sleep(100 * time.Millisecond)
addr := server.Addr()
req, err := http.NewRequest(tt.method, "http://"+addr+"/cluster/status", nil)
if err != nil {
t.Fatalf("failed to create request: %v", err)
}
client := &http.Client{Timeout: 2 * time.Second}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("failed to send request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, resp.StatusCode)
if rec.Code != tt.expectedStatus {
t.Errorf("expected status %d, got %d", tt.expectedStatus, rec.Code)
}
if tt.checkMethodNotAllowed {
@@ -416,7 +350,7 @@ func TestClusterStatusEndpoint(t *testing.T) {
if tt.expectedStatus == http.StatusOK {
var response ClusterStatusResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
if err := json.NewDecoder(rec.Body).Decode(&response); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
@@ -442,41 +376,33 @@ func TestClusterStatusEndpoint(t *testing.T) {
}
func TestContentTypeHeaders(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
Address: "localhost:0",
ShutdownTimeout: 1 * time.Second,
server := New(Config{}, b, nil, slog.Default())
tests := []struct {
name string
handler http.HandlerFunc
}{
{name: "/health", handler: server.handleHealth},
{name: "/ready", handler: server.handleReady},
{name: "/cluster/status", handler: server.handleClusterStatus},
}
server := New(cfg, b, nil, slog.Default())
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://test"+tt.name, nil)
rec := httptest.NewRecorder()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tt.handler(rec, req)
go server.Listen(ctx)
time.Sleep(100 * time.Millisecond)
addr := server.Addr()
endpoints := []string{"/health", "/ready", "/cluster/status"}
for _, endpoint := range endpoints {
t.Run(endpoint, func(t *testing.T) {
resp, err := http.Get("http://" + addr + endpoint)
if err != nil {
t.Fatalf("failed to send request: %v", err)
}
defer resp.Body.Close()
contentType := resp.Header.Get("Content-Type")
contentType := rec.Header().Get("Content-Type")
if contentType != "application/json" {
t.Errorf("expected Content-Type application/json, got %q", contentType)
}
// Verify it's valid JSON
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(rec.Body)
if err != nil {
t.Fatalf("failed to read body: %v", err)
}
@@ -488,56 +414,3 @@ func TestContentTypeHeaders(t *testing.T) {
})
}
}
func TestGracefulShutdown(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
defer b.Close()
cfg := Config{
Address: "localhost:0",
ShutdownTimeout: 2 * time.Second,
}
server := New(cfg, b, nil, slog.Default())
ctx, cancel := context.WithCancel(context.Background())
errCh := make(chan error, 1)
go func() {
errCh <- server.Listen(ctx)
}()
time.Sleep(100 * time.Millisecond)
addr := server.Addr()
// Make a request while server is running
resp, err := http.Get("http://" + addr + "/health")
if err != nil {
t.Fatalf("failed to get health before shutdown: %v", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status 200, got %d", resp.StatusCode)
}
// Trigger shutdown
cancel()
// Server should stop gracefully
select {
case err := <-errCh:
if err != nil {
t.Logf("shutdown completed with: %v", err)
}
case <-time.After(3 * time.Second):
t.Fatal("server did not stop after shutdown timeout")
}
// After shutdown, requests should fail
_, err = http.Get("http://" + addr + "/health")
if err == nil {
t.Error("expected request to fail after shutdown")
}
}
+145 -122
View File
@@ -7,217 +7,240 @@ import (
"context"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/absmach/fluxmq/broker"
"github.com/absmach/fluxmq/config"
"github.com/absmach/fluxmq/mqtt/broker"
)
type stubListener struct {
conns chan net.Conn
closed chan struct{}
addr net.Addr
}
func newStubListener() *stubListener {
return &stubListener{
conns: make(chan net.Conn, 16),
closed: make(chan struct{}),
addr: stubAddr("in-memory"),
}
}
func (l *stubListener) Accept() (net.Conn, error) {
select {
case <-l.closed:
return nil, net.ErrClosed
case conn, ok := <-l.conns:
if !ok {
return nil, net.ErrClosed
}
return conn, nil
}
}
func (l *stubListener) Close() error {
select {
case <-l.closed:
return nil
default:
close(l.closed)
close(l.conns)
return nil
}
}
func (l *stubListener) Addr() net.Addr { return l.addr }
func (l *stubListener) push(conn net.Conn) error {
select {
case <-l.closed:
return net.ErrClosed
default:
l.conns <- conn
return nil
}
}
type stubAddr string
func (a stubAddr) Network() string { return "stub" }
func (a stubAddr) String() string { return string(a) }
type trackingConn struct {
net.Conn
closed atomic.Bool
}
func (c *trackingConn) Close() error {
c.closed.Store(true)
if c.Conn != nil {
return c.Conn.Close()
}
return nil
}
func TestServerStartStop(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
Address: "localhost:0",
ShutdownTimeout: 1 * time.Second,
}
server := New(cfg, b)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
connCtx, connCancel := context.WithCancel(context.Background())
listener := newStubListener()
// Start server in goroutine
errCh := make(chan error, 1)
go func() {
errCh <- server.Listen(ctx)
}()
server.mu.Lock()
server.listener = listener
server.mu.Unlock()
// Wait a bit for server to start
time.Sleep(100 * time.Millisecond)
// Verify server started
if server.Addr() == nil {
t.Fatal("server address is nil after start")
}
// Cancel context to stop server
acceptDone := server.runAcceptLoop(ctx, connCtx, listener)
cancel()
// Wait for server to stop
select {
case err := <-errCh:
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("server did not stop in time")
if err := server.gracefulShutdown(listener, acceptDone, connCancel); err != nil {
t.Fatalf("unexpected shutdown error: %v", err)
}
}
func TestShutdown(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
Address: "localhost:0",
ShutdownTimeout: 5 * time.Second,
}
server := New(cfg, b)
ctx, cancel := context.WithCancel(context.Background())
connCtx, connCancel := context.WithCancel(context.Background())
listener := newStubListener()
// Start server
errCh := make(chan error, 1)
go func() {
errCh <- server.Listen(ctx)
}()
server.mu.Lock()
server.listener = listener
server.mu.Unlock()
time.Sleep(100 * time.Millisecond)
acceptDone := server.runAcceptLoop(ctx, connCtx, listener)
// Connect a client
conn, err := net.Dial("tcp", server.Addr().String())
if err != nil {
t.Fatalf("failed to connect: %v", err)
serverConn, clientConn := net.Pipe()
if err := listener.push(serverConn); err != nil {
t.Fatalf("failed to push connection: %v", err)
}
clientConn.Close()
// Trigger shutdown
cancel()
// Close client connection
conn.Close()
// Server should stop gracefully
select {
case err := <-errCh:
if err != nil {
t.Logf("shutdown completed with: %v", err)
}
case <-time.After(6 * time.Second):
t.Fatal("server did not stop after shutdown timeout")
if err := server.gracefulShutdown(listener, acceptDone, connCancel); err != nil {
t.Fatalf("unexpected shutdown error: %v", err)
}
}
func TestConnectionLimit(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
defer b.Close()
maxConns := 2
maxConns := 1
cfg := Config{
Address: "localhost:0",
MaxConnections: maxConns,
ShutdownTimeout: 1 * time.Second,
}
server := New(cfg, b)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := context.Background()
go server.Listen(ctx)
time.Sleep(100 * time.Millisecond)
// Create max connections
conns := make([]net.Conn, maxConns)
for i := 0; i < maxConns; i++ {
conn, err := net.Dial("tcp", server.Addr().String())
if err != nil {
t.Fatalf("failed to connect %d: %v", i, err)
}
conns[i] = conn
s1, c1 := net.Pipe()
conn1 := &trackingConn{Conn: s1}
if !server.tryAcquireConnectionSlot(ctx, conn1) {
t.Fatal("expected first connection to be accepted")
}
// Wait for connections to be accepted
time.Sleep(200 * time.Millisecond)
// Try one more connection - should be rejected
extraConn, err := net.DialTimeout("tcp", server.Addr().String(), 500*time.Millisecond)
if err == nil {
extraConn.Close()
// Connection might get through briefly before being rejected
// Wait and check if it stays open
time.Sleep(100 * time.Millisecond)
s2, c2 := net.Pipe()
conn2 := &trackingConn{Conn: s2}
if server.tryAcquireConnectionSlot(ctx, conn2) {
t.Fatal("expected second connection to be rejected")
}
if !conn2.closed.Load() {
t.Fatal("expected rejected connection to be closed")
}
// The extra connection should not be handled (or rejected quickly)
// We can't easily test this without instrumenting the handler more
// Clean up
for _, conn := range conns {
if conn != nil {
conn.Close()
}
}
c1.Close()
c2.Close()
server.releaseConnectionSlot()
}
func TestConcurrentConnections(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
Address: "localhost:0",
ShutdownTimeout: 2 * time.Second,
}
server := New(cfg, b)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
connCtx, connCancel := context.WithCancel(context.Background())
listener := newStubListener()
go server.Listen(ctx)
time.Sleep(100 * time.Millisecond)
server.mu.Lock()
server.listener = listener
server.mu.Unlock()
acceptDone := server.runAcceptLoop(ctx, connCtx, listener)
// Create many concurrent connections
numConns := 50
numConns := 20
var wg sync.WaitGroup
wg.Add(numConns)
for i := 0; i < numConns; i++ {
go func() {
defer wg.Done()
conn, err := net.Dial("tcp", server.Addr().String())
if err != nil {
serverConn, clientConn := net.Pipe()
if err := listener.push(serverConn); err != nil {
return
}
conn.Write([]byte("test"))
conn.Close()
clientConn.Close()
}()
}
wg.Wait()
time.Sleep(500 * time.Millisecond)
// All connections should be handled successfully by the broker
cancel()
if err := server.gracefulShutdown(listener, acceptDone, connCancel); err != nil {
t.Fatalf("unexpected shutdown error: %v", err)
}
}
func TestTCPOptimizations(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil)
func TestDefaultConfigApplied(t *testing.T) {
b := broker.NewBroker(nil, nil, nil, nil, nil, nil, nil, config.SessionConfig{})
defer b.Close()
cfg := Config{
Address: "localhost:0",
TCPKeepAlive: 15 * time.Second,
DisableNoDelay: false,
ShutdownTimeout: 1 * time.Second,
server := New(Config{}, b)
if server.config.ShutdownTimeout == 0 {
t.Fatal("expected default ShutdownTimeout to be set")
}
server := New(cfg, b)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.Listen(ctx)
time.Sleep(100 * time.Millisecond)
// Connect and test
conn, err := net.Dial("tcp", server.Addr().String())
if err != nil {
t.Fatalf("failed to connect: %v", err)
if server.config.ReadTimeout == 0 {
t.Fatal("expected default ReadTimeout to be set")
}
if server.config.WriteTimeout == 0 {
t.Fatal("expected default WriteTimeout to be set")
}
if server.config.IdleTimeout == 0 {
t.Fatal("expected default IdleTimeout to be set")
}
if server.config.BufferSize == 0 {
t.Fatal("expected default BufferSize to be set")
}
if server.config.TCPKeepAlive == 0 {
t.Fatal("expected default TCPKeepAlive to be set")
}
conn.Close()
time.Sleep(200 * time.Millisecond)
// TCP options are applied successfully if connection is accepted
}
+130 -332
View File
@@ -6,63 +6,47 @@ package tcp
import (
"context"
"crypto/tls"
"errors"
"io"
"log/slog"
"net"
"os"
"testing"
"time"
"github.com/absmach/fluxmq/broker"
"github.com/absmach/fluxmq/cluster"
"github.com/absmach/fluxmq/config"
"github.com/absmach/fluxmq/mqtt/broker"
v3 "github.com/absmach/fluxmq/mqtt/packets/v3"
"github.com/absmach/fluxmq/storage/memory"
)
func TestTLS_BasicConnection(t *testing.T) {
certs := GenerateTestCerts(t)
tlsConfig := LoadServerTLSConfig(t, certs, tls.NoClientCert)
func newTestBroker(t *testing.T) *broker.Broker {
t.Helper()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
b := broker.NewBroker(nil, nil, logger, nil, nil, nil, nil, config.SessionConfig{})
t.Cleanup(func() { _ = b.Close() })
return b
}
// Create server with TLS
store := memory.New()
cl := cluster.NewNoopCluster("test-node")
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
b := broker.NewBroker(store, cl, nullLogger, nil, nil, nil, nil)
func runHandleConnection(s *Server, conn net.Conn) {
s.wg.Add(1)
go s.handleConnection(context.Background(), conn)
}
cfg := Config{
Address: "127.0.0.1:0", // Random port
TLSConfig: tlsConfig,
ShutdownTimeout: 5 * time.Second,
Logger: nullLogger,
}
server := New(cfg, b)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
func waitForConnections(t *testing.T, s *Server) {
t.Helper()
done := make(chan struct{})
go func() {
errCh <- server.Listen(ctx)
s.wg.Wait()
close(done)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
addr := server.Addr().String()
// Connect with TLS client
clientTLSConfig := LoadClientTLSConfig(t, certs, false)
conn, err := tls.Dial("tcp", addr, clientTLSConfig)
if err != nil {
t.Fatalf("Failed to connect with TLS: %v", err)
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatal("connection handler did not exit in time")
}
defer conn.Close()
}
// Verify TLS handshake completed
if err := conn.Handshake(); err != nil {
t.Fatalf("TLS handshake failed: %v", err)
}
// Send MQTT CONNECT packet
func mqttHandshake(t *testing.T, conn net.Conn) {
t.Helper()
connectPkt := &v3.Connect{
FixedHeader: v3.FixedHeader{PacketType: v3.ConnectType},
ProtocolName: "MQTT",
@@ -72,339 +56,153 @@ func TestTLS_BasicConnection(t *testing.T) {
}
if err := connectPkt.Pack(conn); err != nil {
t.Fatalf("Failed to send CONNECT: %v", err)
t.Fatalf("failed to send CONNECT: %v", err)
}
// Read CONNACK
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
connack, err := v3.ReadPacket(conn)
if err != nil {
t.Fatalf("Failed to read CONNACK: %v", err)
t.Fatalf("failed to read CONNACK: %v", err)
}
if connack.Type() != v3.ConnAckType {
t.Fatalf("Expected CONNACK, got %v", connack.Type())
t.Fatalf("expected CONNACK, got %v", connack.Type())
}
// Send DISCONNECT packet
disconnectPkt := &v3.Disconnect{
FixedHeader: v3.FixedHeader{PacketType: v3.DisconnectType},
}
disconnectPkt.Pack(conn)
conn.Close()
_ = disconnectPkt.Pack(conn)
}
// Give broker time to process disconnect
time.Sleep(100 * time.Millisecond)
// Shutdown server
cancel()
func tlsHandshakeWithTimeout(conn *tls.Conn, timeout time.Duration) error {
errCh := make(chan error, 1)
go func() {
errCh <- conn.Handshake()
}()
select {
case err := <-errCh:
if err != nil {
t.Logf("Server shutdown with error: %v", err)
}
case <-time.After(6 * time.Second):
t.Fatal("Server shutdown timeout")
return err
case <-time.After(timeout):
return errors.New("handshake timeout")
}
}
func TestTLS_BasicConnection(t *testing.T) {
certs := GenerateTestCerts(t)
serverTLS := LoadServerTLSConfig(t, certs, tls.NoClientCert)
clientTLS := LoadClientTLSConfig(t, certs, false)
clientTLS.ServerName = "localhost"
b := newTestBroker(t)
server := New(Config{TLSConfig: serverTLS}, b)
serverConn, clientConn := net.Pipe()
tlsServer := tls.Server(serverConn, serverTLS)
tlsClient := tls.Client(clientConn, clientTLS)
runHandleConnection(server, tlsServer)
if err := tlsHandshakeWithTimeout(tlsClient, 2*time.Second); err != nil {
t.Fatalf("TLS handshake failed: %v", err)
}
mqttHandshake(t, tlsClient)
_ = tlsClient.Close()
waitForConnections(t, server)
}
func TestTLS_RequireClientCert(t *testing.T) {
certs := GenerateTestCerts(t)
tlsConfig := LoadServerTLSConfig(t, certs, tls.RequireAndVerifyClientCert)
serverTLS := LoadServerTLSConfig(t, certs, tls.RequireAndVerifyClientCert)
// Verify TLS config is set up correctly
t.Logf("Server TLS ClientAuth: %v (expected: %v)", tlsConfig.ClientAuth, tls.RequireAndVerifyClientCert)
if tlsConfig.ClientAuth != tls.RequireAndVerifyClientCert {
t.Fatalf("Server TLS config ClientAuth not set correctly")
}
// Create server with TLS requiring client cert
store := memory.New()
cl := cluster.NewNoopCluster("test-node")
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
b := broker.NewBroker(store, cl, logger, nil, nil, nil, nil)
cfg := Config{
Address: "127.0.0.1:0",
TLSConfig: tlsConfig,
ShutdownTimeout: 5 * time.Second,
Logger: logger,
}
server := New(cfg, b)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- server.Listen(ctx)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
addr := server.Addr().String()
// Test 1: Connection without client cert should fail
t.Run("NoClientCert", func(t *testing.T) {
clientTLSConfig := LoadClientTLSConfig(t, certs, false)
conn, err := tls.Dial("tcp", addr, clientTLSConfig)
if err != nil {
// Expected - connection failed during handshake
t.Logf("Connection correctly rejected during dial: %v", err)
return
}
defer conn.Close()
b := newTestBroker(t)
server := New(Config{TLSConfig: serverTLS}, b)
clientTLS := LoadClientTLSConfig(t, certs, false)
clientTLS.ServerName = "localhost"
serverConn, clientConn := net.Pipe()
tlsServer := tls.Server(serverConn, serverTLS)
tlsClient := tls.Client(clientConn, clientTLS)
// Client-side handshake might succeed, but server should close the connection
// Try to read which should fail when server closes due to missing client cert
buf := make([]byte, 1)
conn.SetReadDeadline(time.Now().Add(1 * time.Second))
_, err = conn.Read(buf)
if err != nil {
t.Logf("Connection correctly rejected: %v", err)
return
}
runHandleConnection(server, tlsServer)
t.Fatal("Expected connection to fail without client certificate, but it succeeded")
err := tlsHandshakeWithTimeout(tlsClient, 2*time.Second)
if err == nil {
connectPkt := &v3.Connect{
FixedHeader: v3.FixedHeader{PacketType: v3.ConnectType},
ProtocolName: "MQTT",
ProtocolVersion: 4,
CleanSession: true,
ClientID: "no-cert-client",
}
err = connectPkt.Pack(tlsClient)
if err == nil {
tlsClient.SetReadDeadline(time.Now().Add(1 * time.Second))
_, err = v3.ReadPacket(tlsClient)
}
}
if err == nil {
t.Fatal("expected connection to be rejected without client cert")
}
_ = tlsClient.Close()
waitForConnections(t, server)
})
// Test 2: Connection with client cert should succeed
t.Run("WithClientCert", func(t *testing.T) {
clientTLSConfig := LoadClientTLSConfig(t, certs, true)
conn, err := tls.Dial("tcp", addr, clientTLSConfig)
if err != nil {
t.Fatalf("Failed to connect with client cert: %v", err)
}
defer conn.Close()
b := newTestBroker(t)
server := New(Config{TLSConfig: serverTLS}, b)
clientTLS := LoadClientTLSConfig(t, certs, true)
clientTLS.ServerName = "localhost"
serverConn, clientConn := net.Pipe()
tlsServer := tls.Server(serverConn, serverTLS)
tlsClient := tls.Client(clientConn, clientTLS)
// Verify TLS handshake with client cert
if err := conn.Handshake(); err != nil {
runHandleConnection(server, tlsServer)
if err := tlsHandshakeWithTimeout(tlsClient, 2*time.Second); err != nil {
t.Fatalf("TLS handshake failed: %v", err)
}
// Verify connection state has peer certificates
state := conn.ConnectionState()
if len(state.PeerCertificates) == 0 {
t.Fatal("Server did not receive client certificate")
}
mqttHandshake(t, tlsClient)
_ = tlsClient.Close()
waitForConnections(t, server)
})
// Shutdown server
cancel()
select {
case err := <-errCh:
if err != nil {
t.Logf("Server shutdown with error: %v", err)
}
case <-time.After(6 * time.Second):
t.Fatal("Server shutdown timeout")
}
}
func TestTLS_InvalidCert(t *testing.T) {
func TestTLS_UntrustedServer(t *testing.T) {
certs := GenerateTestCerts(t)
tlsConfig := LoadServerTLSConfig(t, certs, tls.NoClientCert)
serverTLS := LoadServerTLSConfig(t, certs, tls.NoClientCert)
// Create server
store := memory.New()
cl := cluster.NewNoopCluster("test-node")
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
b := broker.NewBroker(store, cl, nullLogger, nil, nil, nil, nil)
b := newTestBroker(t)
server := New(Config{TLSConfig: serverTLS}, b)
cfg := Config{
Address: "127.0.0.1:0",
TLSConfig: tlsConfig,
ShutdownTimeout: 5 * time.Second,
Logger: nullLogger,
}
server := New(cfg, b)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- server.Listen(ctx)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
addr := server.Addr().String()
// Try to connect without trusting the server's CA
insecureTLSConfig := &tls.Config{
InsecureSkipVerify: false, // Explicitly don't skip verification
}
conn, err := tls.Dial("tcp", addr, insecureTLSConfig)
if err == nil {
conn.Close()
t.Fatal("Expected connection to fail with unverified certificate")
}
// Shutdown server
cancel()
select {
case err := <-errCh:
if err != nil {
t.Logf("Server shutdown with error: %v", err)
}
case <-time.After(6 * time.Second):
t.Fatal("Server shutdown timeout")
serverConn, clientConn := net.Pipe()
tlsServer := tls.Server(serverConn, serverTLS)
tlsClient := tls.Client(clientConn, &tls.Config{InsecureSkipVerify: false, ServerName: "localhost"})
runHandleConnection(server, tlsServer)
if err := tlsHandshakeWithTimeout(tlsClient, 2*time.Second); err == nil {
t.Fatal("expected TLS handshake to fail for untrusted server cert")
}
_ = tlsClient.Close()
waitForConnections(t, server)
}
func TestTLS_MinVersion(t *testing.T) {
func TestTLS_MinVersionConfig(t *testing.T) {
certs := GenerateTestCerts(t)
tlsConfig := LoadServerTLSConfig(t, certs, tls.NoClientCert)
// Verify minimum TLS version is enforced
if tlsConfig.MinVersion != tls.VersionTLS12 {
t.Fatalf("Expected MinVersion to be TLS 1.2, got %v", tlsConfig.MinVersion)
}
// Create server
store := memory.New()
cl := cluster.NewNoopCluster("test-node")
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
b := broker.NewBroker(store, cl, nullLogger, nil, nil, nil, nil)
cfg := Config{
Address: "127.0.0.1:0",
TLSConfig: tlsConfig,
ShutdownTimeout: 5 * time.Second,
Logger: nullLogger,
}
server := New(cfg, b)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- server.Listen(ctx)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
addr := server.Addr().String()
// Try to connect with TLS 1.1 (should fail)
clientTLSConfig := LoadClientTLSConfig(t, certs, false)
clientTLSConfig.MaxVersion = tls.VersionTLS11
conn, err := tls.Dial("tcp", addr, clientTLSConfig)
if err == nil {
conn.Close()
// Note: This might not fail in all Go versions as the client's MaxVersion
// might be overridden. The important thing is the server enforces MinVersion.
t.Log("Note: Client was able to connect with TLS 1.1 (client-side compatibility)")
} else {
t.Logf("Connection correctly rejected with TLS 1.1: %v", err)
}
// Verify TLS 1.2+ works
clientTLSConfig.MaxVersion = tls.VersionTLS13
clientTLSConfig.MinVersion = tls.VersionTLS12
conn, err = tls.Dial("tcp", addr, clientTLSConfig)
if err != nil {
t.Fatalf("Failed to connect with TLS 1.2+: %v", err)
}
conn.Close()
// Shutdown server
cancel()
select {
case err := <-errCh:
if err != nil {
t.Logf("Server shutdown with error: %v", err)
}
case <-time.After(6 * time.Second):
t.Fatal("Server shutdown timeout")
serverTLS := LoadServerTLSConfig(t, certs, tls.NoClientCert)
if serverTLS.MinVersion != tls.VersionTLS12 {
t.Fatalf("expected MinVersion to be TLS 1.2, got %v", serverTLS.MinVersion)
}
}
func TestTLS_NoTLS(t *testing.T) {
// Verify server works without TLS when TLSConfig is nil
store := memory.New()
cl := cluster.NewNoopCluster("test-node")
nullLogger := slog.New(slog.NewTextHandler(os.NewFile(0, os.DevNull), nil))
b := broker.NewBroker(store, cl, nullLogger, nil, nil, nil, nil)
b := newTestBroker(t)
server := New(Config{}, b)
cfg := Config{
Address: "127.0.0.1:0",
TLSConfig: nil, // No TLS
ShutdownTimeout: 5 * time.Second,
Logger: nullLogger,
}
server := New(cfg, b)
serverConn, clientConn := net.Pipe()
runHandleConnection(server, serverConn)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errCh := make(chan error, 1)
go func() {
errCh <- server.Listen(ctx)
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
addr := server.Addr().String()
// Connect without TLS
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Failed to connect without TLS: %v", err)
}
defer conn.Close()
// Send MQTT CONNECT packet
connectPkt := &v3.Connect{
FixedHeader: v3.FixedHeader{PacketType: v3.ConnectType},
ProtocolName: "MQTT",
ProtocolVersion: 4,
CleanSession: true,
ClientID: "plain-test-client",
}
if err := connectPkt.Pack(conn); err != nil {
t.Fatalf("Failed to send CONNECT: %v", err)
}
// Read CONNACK
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
connack, err := v3.ReadPacket(conn)
if err != nil {
t.Fatalf("Failed to read CONNACK: %v", err)
}
if connack.Type() != v3.ConnAckType {
t.Fatalf("Expected CONNACK, got %v", connack.Type())
}
// Send DISCONNECT packet
disconnectPkt2 := &v3.Disconnect{
FixedHeader: v3.FixedHeader{PacketType: v3.DisconnectType},
}
disconnectPkt2.Pack(conn)
conn.Close()
// Give broker time to process disconnect
time.Sleep(100 * time.Millisecond)
// Shutdown server
cancel()
select {
case err := <-errCh:
if err != nil {
t.Logf("Server shutdown with error: %v", err)
}
case <-time.After(6 * time.Second):
t.Fatal("Server shutdown timeout")
}
mqttHandshake(t, clientConn)
_ = clientConn.Close()
waitForConnections(t, server)
}
+3 -2
View File
@@ -12,8 +12,9 @@ import (
"testing"
"time"
"github.com/absmach/fluxmq/broker"
"github.com/absmach/fluxmq/cluster"
"github.com/absmach/fluxmq/config"
"github.com/absmach/fluxmq/mqtt/broker"
"github.com/absmach/fluxmq/server/tcp"
"github.com/absmach/fluxmq/storage/badger"
"github.com/stretchr/testify/require"
@@ -193,7 +194,7 @@ func (tc *TestCluster) startNode(node *TestNode, bootstrap bool, peerTransports
store.Close()
return fmt.Errorf("failed to create cluster: %w", err)
}
b := broker.NewBroker(store, clust, nullLogger, nil, nil, nil, nil)
b := broker.NewBroker(store, clust, nullLogger, nil, nil, nil, nil, config.SessionConfig{})
// Wire broker as message handler (includes session management)
clust.SetMessageHandler(b)