Fix tests

Signed-off-by: dusan <borovcanindusan1@gmail.com>
This commit is contained in:
dusan
2026-02-05 12:47:16 +01:00
parent 81b3718ca9
commit 28944064d1
5 changed files with 87 additions and 11 deletions
+6 -1
View File
@@ -52,7 +52,12 @@ run-debug: build
# Run tests
.PHONY: test
test:
$(GO) test -v ./...
$(GO) test -short -race -failfast -timeout 3m -v ./...
# Run full tests (including stress)
.PHONY: test-full
test-full:
$(GO) test -race -count=1 -v -timeout 30m ./...
# Run tests with coverage
.PHONY: test-cover
+3 -2
View File
@@ -93,10 +93,11 @@ func (r *TrieRouter) Match(topic string) ([]*storage.Subscription, error) {
levels := strings.Split(topic, separator)
matched := AcquireSubscriptionSlice()
matchLevel(r.root, levels, 0, matched)
result := *matched
// Copy out before releasing the pooled slice to avoid data races
// when the pool reuses the backing array in other goroutines.
result := append([]*storage.Subscription(nil), (*matched)...)
// Release the pooled slice pointer back to pool
// The slice data itself is now referenced by 'result'
ReleaseSubscriptionSlice(matched)
return result, nil
+30 -7
View File
@@ -54,6 +54,7 @@ func New(cfg Config, b *broker.Broker, logger *slog.Logger) *Server {
s.mux.Handle("/mqtt/publish/{topic}", mux.HandlerFunc(s.handlePublish))
s.mux.Handle("/health", mux.HandlerFunc(s.handleHealth))
s.mux.DefaultHandleFunc(s.handlePublish)
return s
}
@@ -145,19 +146,41 @@ func (s *Server) handlePublish(w mux.ResponseWriter, r *mux.Message) {
return
}
// if !strings.HasPrefix(path, "/mqtt/publish/") {
// s.logger.Warn("coap_publish_invalid_path", slog.String("path", path))
// s.sendResponse(w, r, codes.BadRequest, "invalid path")
// return
// }
// Normalize path to always start with "/"
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
topic := strings.TrimPrefix(path, "/mqtt/publish/")
const (
publishPrefix = "/mqtt/publish/"
reservedMQTT = "/mqtt"
healthPath = "/health"
)
switch {
case path == "/mqtt/publish" || path == publishPrefix:
s.logger.Warn("coap_publish_missing_topic")
s.sendResponse(w, r, codes.BadRequest, "topic is required in path")
return
case strings.HasPrefix(path, publishPrefix):
// Legacy path: /mqtt/publish/<topic>
path = "/" + strings.TrimPrefix(path, publishPrefix)
case path == healthPath:
s.logger.Warn("coap_publish_invalid_path", slog.String("path", path))
s.sendResponse(w, r, codes.BadRequest, "invalid path")
return
case path == reservedMQTT || strings.HasPrefix(path, reservedMQTT+"/"):
s.logger.Warn("coap_publish_invalid_path", slog.String("path", path))
s.sendResponse(w, r, codes.BadRequest, "invalid path")
return
}
topic := strings.TrimPrefix(path, "/")
if topic == "" {
s.logger.Warn("coap_publish_missing_topic")
s.sendResponse(w, r, codes.BadRequest, "topic is required in path")
return
}
payload, err := r.ReadBody()
if err != nil {
s.logger.Warn("coap_publish_read_body_error", slog.String("error", err.Error()))
+38
View File
@@ -217,6 +217,44 @@ func TestHandlePublish(t *testing.T) {
}
})
t.Run("non-mqtt path ok", func(t *testing.T) {
conn := newStubConn()
writer := &stubResponseWriter{conn: conn}
reqMsg := pool.NewMessage(context.Background())
reqMsg.MustSetPath("/test/topic")
reqMsg.SetBody(bytes.NewReader([]byte("payload")))
req := &mux.Message{Message: reqMsg}
server.handlePublish(writer, req)
if conn.last == nil {
t.Fatal("expected response message")
}
if conn.last.Code() != codes.Changed {
t.Fatalf("expected code %v, got %v", codes.Changed, conn.last.Code())
}
})
t.Run("reserved mqtt path", func(t *testing.T) {
conn := newStubConn()
writer := &stubResponseWriter{conn: conn}
reqMsg := pool.NewMessage(context.Background())
reqMsg.MustSetPath("/mqtt/other")
reqMsg.SetBody(bytes.NewReader([]byte("payload")))
req := &mux.Message{Message: reqMsg}
server.handlePublish(writer, req)
if conn.last == nil {
t.Fatal("expected response message")
}
if conn.last.Code() != codes.BadRequest {
t.Fatalf("expected code %v, got %v", codes.BadRequest, conn.last.Code())
}
})
t.Run("ok", func(t *testing.T) {
conn := newStubConn()
writer := &stubResponseWriter{conn: conn}
+10 -1
View File
@@ -83,6 +83,7 @@ func tlsHandshakeWithTimeout(conn *tls.Conn, timeout time.Duration) error {
case err := <-errCh:
return err
case <-time.After(timeout):
_ = conn.Close()
return errors.New("handshake timeout")
}
}
@@ -120,6 +121,8 @@ func TestTLS_RequireClientCert(t *testing.T) {
clientTLS := LoadClientTLSConfig(t, certs, false)
clientTLS.ServerName = "localhost"
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
tlsServer := tls.Server(serverConn, serverTLS)
tlsClient := tls.Client(clientConn, clientTLS)
@@ -143,7 +146,13 @@ func TestTLS_RequireClientCert(t *testing.T) {
if err == nil {
t.Fatal("expected connection to be rejected without client cert")
}
_ = tlsClient.Close()
// Ensure server side unblocks even if handshake failed mid-flight.
if err != nil {
_ = clientConn.Close()
_ = serverConn.Close()
} else {
_ = tlsClient.Close()
}
waitForConnections(t, server)
})