SMQ-2956 - Add routes resolver to abstract topic resolution logic (#2960)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2025-07-01 16:04:01 +03:00
committed by GitHub
parent 9a194a86df
commit 0d1aafcbbf
22 changed files with 628 additions and 398 deletions
+3 -1
View File
@@ -20,6 +20,7 @@ import (
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
"github.com/absmach/supermq/pkg/grpcclient"
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/messaging/brokers"
brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing"
msgevents "github.com/absmach/supermq/pkg/messaging/events"
@@ -181,7 +182,8 @@ func main() {
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(cfg.InstanceID), logger)
cs := coapserver.NewServer(ctx, cancel, svcName, coapServerConfig, httpapi.MakeCoAPHandler(svc, channelsClient, domainsClient, logger), logger)
resolver := messaging.NewTopicResolver(channelsClient, domainsClient)
cs := coapserver.NewServer(ctx, cancel, svcName, coapServerConfig, httpapi.MakeCoAPHandler(svc, channelsClient, resolver, logger), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
+2 -1
View File
@@ -221,7 +221,8 @@ func main() {
}
func newService(pub messaging.Publisher, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, logger *slog.Logger, tracer trace.Tracer) session.Handler {
svc := adapter.NewHandler(pub, authn, clients, channels, domains, logger)
resolver := messaging.NewTopicResolver(channels, domains)
svc := adapter.NewHandler(pub, authn, clients, channels, resolver, logger)
svc = handler.NewTracing(tracer, svc)
svc = handler.LoggingMiddleware(svc, logger)
counter, latency := prometheus.MakeMetrics(svcName, "api")
+5 -59
View File
@@ -23,10 +23,6 @@ import (
"github.com/absmach/mgate/pkg/mqtt/websocket"
"github.com/absmach/mgate/pkg/session"
"github.com/absmach/supermq"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
api "github.com/absmach/supermq/api/http"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/mqtt"
"github.com/absmach/supermq/mqtt/events"
@@ -57,11 +53,6 @@ const (
wsPathPrefix = "/mqtt"
)
var (
errFailedResolveDomain = errors.New("failed to resolve domain route")
errFailedResolveChannel = errors.New("failed to resolve channel route")
)
type config struct {
LogLevel string `env:"SMQ_MQTT_ADAPTER_LOG_LEVEL" envDefault:"info"`
MQTTPort string `env:"SMQ_MQTT_ADAPTER_MQTT_PORT" envDefault:"1883"`
@@ -247,8 +238,7 @@ func main() {
}
beforeHandler := beforeHandler{
domains: domainsClient,
channels: channelsClient,
resolver: messaging.NewTopicResolver(channelsClient, domainsClient),
}
afterHandler := afterHandler{
@@ -380,8 +370,7 @@ func (ah afterHandler) Intercept(ctx context.Context, pkt packets.ControlPacket,
}
type beforeHandler struct {
domains grpcDomainsV1.DomainsServiceClient
channels grpcChannelsV1.ChannelsServiceClient
resolver messaging.TopicResolver
}
// This interceptor is used to replace domain and channel routes with relevant domain and channel IDs in the message topic.
@@ -389,7 +378,7 @@ func (bh beforeHandler) Intercept(ctx context.Context, pkt packets.ControlPacket
switch pt := pkt.(type) {
case *packets.SubscribePacket:
for i, topic := range pt.Topics {
ft, err := bh.resolveTopic(ctx, topic)
ft, err := bh.resolver.ResolveTopic(ctx, topic)
if err != nil {
return nil, err
}
@@ -399,7 +388,7 @@ func (bh beforeHandler) Intercept(ctx context.Context, pkt packets.ControlPacket
return pt, nil
case *packets.UnsubscribePacket:
for i, topic := range pt.Topics {
ft, err := bh.resolveTopic(ctx, topic)
ft, err := bh.resolver.ResolveTopic(ctx, topic)
if err != nil {
return nil, err
}
@@ -407,7 +396,7 @@ func (bh beforeHandler) Intercept(ctx context.Context, pkt packets.ControlPacket
}
return pt, nil
case *packets.PublishPacket:
ft, err := bh.resolveTopic(ctx, pt.TopicName)
ft, err := bh.resolver.ResolveTopic(ctx, pt.TopicName)
if err != nil {
return nil, err
}
@@ -418,46 +407,3 @@ func (bh beforeHandler) Intercept(ctx context.Context, pkt packets.ControlPacket
return pkt, nil
}
func (bh beforeHandler) resolveTopic(ctx context.Context, topic string) (string, error) {
matches := messaging.TopicRegExp.FindStringSubmatch(topic)
if len(matches) < 4 {
return "", messaging.ErrMalformedTopic
}
domainID, err := bh.resolveDomain(ctx, matches[1])
if err != nil {
return "", errors.Wrap(errFailedResolveDomain, err)
}
channelID, err := bh.resolveChannel(ctx, matches[2], domainID)
if err != nil {
return "", errors.Wrap(errFailedResolveChannel, err)
}
return fmt.Sprintf("m/%s/c/%s%s", domainID, channelID, matches[3]), nil
}
func (bh beforeHandler) resolveDomain(ctx context.Context, domain string) (string, error) {
if api.ValidateUUID(domain) == nil {
return domain, nil
}
resp, err := bh.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{Route: domain})
if err != nil {
return "", err
}
return resp.Entity.Id, nil
}
func (bh beforeHandler) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
if api.ValidateUUID(channel) == nil {
return channel, nil
}
resp, err := bh.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: channel,
DomainId: domainID,
})
if err != nil {
return "", err
}
return resp.Entity.Id, nil
}
+6 -6
View File
@@ -19,7 +19,6 @@ import (
"github.com/absmach/supermq"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/pkg/authn/authsvc"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
@@ -194,10 +193,11 @@ func main() {
exitCode = 1
return
}
resolver := messaging.NewTopicResolver(channelsClient, domainsClient)
svc := newService(clientsClient, channelsClient, domainsClient, nps, logger, tracer)
svc := newService(clientsClient, channelsClient, nps, logger, tracer)
hs := httpserver.NewServer(ctx, cancel, svcName, targetServerConfig, httpapi.MakeHandler(ctx, svc, logger, cfg.InstanceID), logger)
hs := httpserver.NewServer(ctx, cancel, svcName, targetServerConfig, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
@@ -209,7 +209,7 @@ func main() {
})
g.Go(func() error {
handler := ws.NewHandler(nps, logger, authn, clientsClient, channelsClient, domainsClient)
handler := ws.NewHandler(nps, logger, authn, clientsClient, channelsClient, resolver)
return proxyWS(ctx, httpServerConfig, targetServerConfig, logger, handler)
})
@@ -222,8 +222,8 @@ func main() {
}
}
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service {
svc := ws.New(clientsClient, channels, domains, nps)
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service {
svc := ws.New(clientsClient, channels, nps)
svc = tracing.New(tracer, svc)
svc = httpapi.LoggingMiddleware(svc, logger)
counter, latency := prometheus.MakeMetrics("ws_adapter", "api")
+36 -74
View File
@@ -14,9 +14,6 @@ import (
"github.com/absmach/supermq"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
api "github.com/absmach/supermq/api/http"
"github.com/absmach/supermq/coap"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -36,17 +33,8 @@ const (
)
var (
errBadOptions = errors.New("bad options")
errMethodNotAllowed = errors.New("method not allowed")
errFailedResolveDomain = errors.New("failed to resolve domain route")
errFailedResolveChannel = errors.New("failed to resolve channel route")
)
var (
logger *slog.Logger
service coap.Service
channels grpcChannelsV1.ChannelsServiceClient
domains grpcDomainsV1.DomainsServiceClient
errBadOptions = errors.New("bad options")
errMethodNotAllowed = errors.New("method not allowed")
)
// MakeHandler returns a HTTP handler for API endpoints.
@@ -58,39 +46,47 @@ func MakeHandler(instanceID string) http.Handler {
return b
}
// MakeCoAPHandler creates handler for CoAP messages.
func MakeCoAPHandler(svc coap.Service, channelsClient grpcChannelsV1.ChannelsServiceClient, domainsClient grpcDomainsV1.DomainsServiceClient, l *slog.Logger) mux.HandlerFunc {
logger = l
service = svc
channels = channelsClient
domains = domainsClient
return handler
type CoAPHandler struct {
logger *slog.Logger
service coap.Service
channels grpcChannelsV1.ChannelsServiceClient
resolver messaging.TopicResolver
}
func sendResp(w mux.ResponseWriter, resp *pool.Message) {
// MakeCoAPHandler creates handler for CoAP messages.
func MakeCoAPHandler(svc coap.Service, channelsClient grpcChannelsV1.ChannelsServiceClient, resolver messaging.TopicResolver, l *slog.Logger) mux.HandlerFunc {
h := &CoAPHandler{
logger: l,
service: svc,
channels: channelsClient,
resolver: resolver,
}
return h.handler
}
func (h *CoAPHandler) sendResp(w mux.ResponseWriter, resp *pool.Message) {
if err := w.Conn().WriteMessage(resp); err != nil {
logger.Warn(fmt.Sprintf("Can't set response: %s", err))
h.logger.Warn(fmt.Sprintf("Can't set response: %s", err))
}
}
func handler(w mux.ResponseWriter, m *mux.Message) {
func (h *CoAPHandler) handler(w mux.ResponseWriter, m *mux.Message) {
resp := pool.NewMessage(w.Conn().Context())
resp.SetToken(m.Token())
for _, opt := range m.Options() {
resp.AddOptionBytes(opt.ID, opt.Value)
}
defer sendResp(w, resp)
defer h.sendResp(w, resp)
msg, err := decodeMessage(m)
msg, err := h.decodeMessage(m)
if err != nil {
logger.Warn(fmt.Sprintf("Error decoding message: %s", err))
h.logger.Warn(fmt.Sprintf("Error decoding message: %s", err))
resp.SetCode(codes.BadRequest)
return
}
key, err := parseKey(m)
if err != nil {
logger.Warn(fmt.Sprintf("Error parsing auth: %s", err))
h.logger.Warn(fmt.Sprintf("Error parsing auth: %s", err))
resp.SetCode(codes.Unauthorized)
return
}
@@ -98,10 +94,10 @@ func handler(w mux.ResponseWriter, m *mux.Message) {
switch m.Code() {
case codes.GET:
resp.SetCode(codes.Content)
err = handleGet(m, w, msg, key)
err = h.handleGet(m, w, msg, key)
case codes.POST:
resp.SetCode(codes.Created)
err = service.Publish(m.Context(), key, msg)
err = h.service.Publish(m.Context(), key, msg)
default:
err = errMethodNotAllowed
}
@@ -122,24 +118,24 @@ func handler(w mux.ResponseWriter, m *mux.Message) {
}
}
func handleGet(m *mux.Message, w mux.ResponseWriter, msg *messaging.Message, key string) error {
func (h *CoAPHandler) handleGet(m *mux.Message, w mux.ResponseWriter, msg *messaging.Message, key string) error {
var obs uint32
obs, err := m.Options().Observe()
if err != nil {
logger.Warn(fmt.Sprintf("Error reading observe option: %s", err))
h.logger.Warn(fmt.Sprintf("Error reading observe option: %s", err))
return errBadOptions
}
if obs == startObserve {
c := coap.NewClient(w.Conn(), m.Token(), logger)
c := coap.NewClient(w.Conn(), m.Token(), h.logger)
w.Conn().AddOnClose(func() {
_ = service.DisconnectHandler(context.Background(), msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c.Token())
_ = h.service.DisconnectHandler(context.Background(), msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c.Token())
})
return service.Subscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c)
return h.service.Subscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c)
}
return service.Unsubscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), m.Token().String())
return h.service.Unsubscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), m.Token().String())
}
func decodeMessage(msg *mux.Message) (*messaging.Message, error) {
func (h *CoAPHandler) decodeMessage(msg *mux.Message) (*messaging.Message, error) {
if msg.Options() == nil {
return &messaging.Message{}, errBadOptions
}
@@ -159,14 +155,9 @@ func decodeMessage(msg *mux.Message) (*messaging.Message, error) {
return &messaging.Message{}, err
}
domainID, err := resolveDomain(msg.Context(), domain)
domainID, channelID, err := h.resolver.Resolve(msg.Context(), domain, channel)
if err != nil {
return &messaging.Message{}, errors.Wrap(errFailedResolveDomain, err)
}
channelID, err := resolveChannel(msg.Context(), channel, domainID)
if err != nil {
return &messaging.Message{}, errors.Wrap(errFailedResolveChannel, err)
return &messaging.Message{}, err
}
ret := &messaging.Message{
@@ -199,32 +190,3 @@ func parseKey(msg *mux.Message) (string, error) {
}
return vars[1], nil
}
func resolveDomain(ctx context.Context, domain string) (string, error) {
if api.ValidateUUID(domain) == nil {
return domain, nil
}
d, err := domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: domain,
})
if err != nil {
return "", err
}
return d.Entity.Id, nil
}
func resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
if api.ValidateUUID(channel) == nil {
return channel, nil
}
c, err := channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: channel,
DomainId: domainID,
})
if err != nil {
return "", err
}
return c.Entity.Id, nil
}
+2 -1
View File
@@ -51,7 +51,8 @@ var (
func newService(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (session.Handler, *pubsub.PubSub) {
pub := new(pubsub.PubSub)
return server.NewHandler(pub, authn, clients, channels, domains, smqlog.NewMock()), pub
resolver := messaging.NewTopicResolver(channels, domains)
return server.NewHandler(pub, authn, clients, channels, resolver, smqlog.NewMock()), pub
}
func newTargetHTTPServer() *httptest.Server {
+5 -43
View File
@@ -15,9 +15,6 @@ import (
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
@@ -52,8 +49,6 @@ var (
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
errMalformedTopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed topic"))
errMissingTopicPub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to publish due to missing topic"))
errFailedResolveDomain = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to resolve domain route"))
errFailedResolveChannel = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to resolve channel route"))
)
// Event implements events.Event interface.
@@ -61,19 +56,19 @@ type handler struct {
publisher messaging.Publisher
clients grpcClientsV1.ClientsServiceClient
channels grpcChannelsV1.ChannelsServiceClient
domains grpcDomainsV1.DomainsServiceClient
resolver messaging.TopicResolver
authn smqauthn.Authentication
logger *slog.Logger
}
// NewHandler creates new Handler entity.
func NewHandler(publisher messaging.Publisher, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, logger *slog.Logger) session.Handler {
func NewHandler(publisher messaging.Publisher, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, resolver messaging.TopicResolver, logger *slog.Logger) session.Handler {
return &handler{
publisher: publisher,
authn: authn,
clients: clients,
channels: channels,
domains: domains,
resolver: resolver,
logger: logger,
}
}
@@ -130,13 +125,9 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
if err != nil {
return errors.Wrap(errMalformedTopic, err)
}
domainID, err := h.resolveDomain(ctx, domain)
domainID, channelID, err := h.resolver.Resolve(ctx, domain, channel)
if err != nil {
return errors.Wrap(errFailedResolveDomain, err)
}
channelID, err := h.resolveChannel(ctx, channel, domainID)
if err != nil {
return errors.Wrap(errFailedResolveChannel, err)
return errors.Wrap(errFailedPublish, err)
}
var clientID, clientType string
@@ -218,32 +209,3 @@ func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error {
func (h *handler) Disconnect(ctx context.Context) error {
return nil
}
func (h *handler) resolveDomain(ctx context.Context, domain string) (string, error) {
if api.ValidateUUID(domain) == nil {
return domain, nil
}
d, err := h.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: domain,
})
if err != nil {
return "", err
}
return d.Entity.Id, nil
}
func (h *handler) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
if api.ValidateUUID(channel) == nil {
return channel, nil
}
c, err := h.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: channel,
DomainId: domainID,
})
if err != nil {
return "", err
}
return c.Entity.Id, nil
}
+3 -2
View File
@@ -25,6 +25,7 @@ import (
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/messaging/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@@ -74,10 +75,10 @@ func newHandler() session.Handler {
authn = new(authnmocks.Authentication)
clients = new(clmocks.ClientsServiceClient)
channels = new(chmocks.ChannelsServiceClient)
domains = new(dmocks.DomainsServiceClient)
publisher = new(mocks.PubSub)
resolver := messaging.NewTopicResolver(channels, domains)
return mhttp.NewHandler(publisher, authn, clients, channels, domains, logger)
return mhttp.NewHandler(publisher, authn, clients, channels, resolver, logger)
}
func TestAuthConnect(t *testing.T) {
+1 -1
View File
@@ -16,7 +16,7 @@ import (
)
// SubjectAllMessages represents subject to subscribe for all the messages.
const SubjectAllMessages = messaging.MsgTopicPrefix + ".>"
const SubjectAllMessages = string(messaging.MsgTopicPrefix) + ".>"
func init() {
log.Println("The binary was build using Nats as the message broker")
+1 -1
View File
@@ -16,7 +16,7 @@ import (
)
// SubjectAllMessages represents subject to subscribe for all the messages.
const SubjectAllMessages = messaging.MsgTopicPrefix + ".#"
const SubjectAllMessages = string(messaging.MsgTopicPrefix) + ".#"
func init() {
log.Println("The binary was build using RabbitMQ as the message broker")
+110
View File
@@ -0,0 +1,110 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package messaging
import (
"context"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
"github.com/absmach/supermq/pkg/errors"
"github.com/gofrs/uuid/v5"
)
var (
ErrEmptyRouteID = errors.New("empty route or id")
ErrFailedResolveDomain = errors.New("failed to resolve domain route")
ErrFailedResolveChannel = errors.New("failed to resolve channel route")
)
// TopicResolver contains definitions for resolving domain and channel IDs
// from their respective routes from the message topic.
type TopicResolver interface {
Resolve(ctx context.Context, domain, channel string) (domainID string, channelID string, err error)
ResolveTopic(ctx context.Context, topic string) (rtopic string, err error)
}
type resolver struct {
channels grpcChannelsV1.ChannelsServiceClient
domains grpcDomainsV1.DomainsServiceClient
}
// NewTopicResolver creates a new instance of TopicResolver.
func NewTopicResolver(channelsClient grpcChannelsV1.ChannelsServiceClient, domainsClient grpcDomainsV1.DomainsServiceClient) TopicResolver {
return &resolver{
channels: channelsClient,
domains: domainsClient,
}
}
func (r *resolver) Resolve(ctx context.Context, domain, channel string) (string, string, error) {
if domain == "" || channel == "" {
return "", "", ErrEmptyRouteID
}
domainID, err := r.resolveDomain(ctx, domain)
if err != nil {
return "", "", errors.Wrap(ErrFailedResolveDomain, err)
}
channelID, err := r.resolveChannel(ctx, channel, domainID)
if err != nil {
return "", "", errors.Wrap(ErrFailedResolveChannel, err)
}
return domainID, channelID, nil
}
func (r *resolver) ResolveTopic(ctx context.Context, topic string) (string, error) {
domain, channel, subtopic, err := ParseTopic(topic)
if err != nil {
return "", errors.Wrap(ErrMalformedTopic, err)
}
domainID, channelID, err := r.Resolve(ctx, domain, channel)
if err != nil {
return "", err
}
rtopic := EncodeAdapterTopic(domainID, channelID, subtopic)
return rtopic, nil
}
func (r *resolver) resolveDomain(ctx context.Context, domain string) (string, error) {
if validateUUID(domain) == nil {
return domain, nil
}
d, err := r.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: domain,
})
if err != nil {
return "", err
}
return d.Entity.Id, nil
}
func (r *resolver) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
if validateUUID(channel) == nil {
return channel, nil
}
c, err := r.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: channel,
DomainId: domainID,
})
if err != nil {
return "", err
}
return c.Entity.Id, nil
}
func validateUUID(extID string) (err error) {
id, err := uuid.FromString(extID)
if id.String() != extID || err != nil {
return err
}
return nil
}
+248
View File
@@ -0,0 +1,248 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package messaging_test
import (
"context"
"fmt"
"testing"
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
chmocks "github.com/absmach/supermq/channels/mocks"
dmocks "github.com/absmach/supermq/domains/mocks"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
var (
validRoute = "valid-route"
invalidRoute = "invalid-route"
channelID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
topicFmt = "m/%s/c/%s"
)
func setupResolver() (messaging.TopicResolver, *dmocks.DomainsServiceClient, *chmocks.ChannelsServiceClient) {
channels := new(chmocks.ChannelsServiceClient)
domains := new(dmocks.DomainsServiceClient)
resolver := messaging.NewTopicResolver(channels, domains)
return resolver, domains, channels
}
func TestResolve(t *testing.T) {
resolver, domains, channels := setupResolver()
cases := []struct {
desc string
domain string
channel string
domainID string
channelID string
domainsErr error
channelsErr error
err error
}{
{
desc: "valid domainID and channelID",
domain: domainID,
channel: channelID,
domainID: domainID,
channelID: channelID,
err: nil,
},
{
desc: "valid domain route and channel ID",
domain: validRoute,
channel: channelID,
domainID: domainID,
channelID: channelID,
err: nil,
},
{
desc: "valid domain ID and channel route",
domain: domainID,
channel: validRoute,
domainID: domainID,
channelID: channelID,
err: nil,
},
{
desc: "valid domain route and channel route",
domain: validRoute,
channel: validRoute,
domainID: domainID,
channelID: channelID,
err: nil,
},
{
desc: "invalid domain route and valid channel",
domain: invalidRoute,
channel: channelID,
domainID: "",
channelID: "",
domainsErr: svcerr.ErrNotFound,
err: messaging.ErrFailedResolveDomain,
},
{
desc: "valid domain and invalid channel",
domain: domainID,
channel: invalidRoute,
domainID: domainID,
channelID: "",
channelsErr: svcerr.ErrNotFound,
err: messaging.ErrFailedResolveChannel,
},
{
desc: "empty domain",
domain: "",
channel: channelID,
domainID: "",
channelID: "",
err: messaging.ErrEmptyRouteID,
},
{
desc: "empty channel",
domain: domainID,
channel: "",
domainID: domainID,
channelID: "",
err: messaging.ErrEmptyRouteID,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
domainsCall := domains.On("RetrieveByRoute", mock.Anything, &grpcCommonV1.RetrieveByRouteReq{Route: tc.domain}).Return(&grpcCommonV1.RetrieveEntityRes{
Entity: &grpcCommonV1.EntityBasic{
Id: tc.domainID,
},
}, tc.domainsErr)
channelsCall := channels.On("RetrieveByRoute", mock.Anything, &grpcCommonV1.RetrieveByRouteReq{Route: tc.channel, DomainId: tc.domainID}).Return(&grpcCommonV1.RetrieveEntityRes{
Entity: &grpcCommonV1.EntityBasic{
Id: tc.channelID,
},
}, tc.channelsErr)
domainID, channelID, err := resolver.Resolve(context.Background(), tc.domain, tc.channel)
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
if err == nil {
assert.Equal(t, tc.domainID, domainID, "expected domain ID %s, got %s", tc.domainID, domainID)
assert.Equal(t, tc.channelID, channelID, "expected channel ID %s, got %s", tc.channelID, channelID)
}
domainsCall.Unset()
channelsCall.Unset()
})
}
}
func TestResolveTopic(t *testing.T) {
resolver, domains, channels := setupResolver()
cases := []struct {
desc string
topic string
domain string
channel string
domainID string
channelID string
domainsErr error
channelsErr error
response string
err error
}{
{
desc: "valid topic with domainID and channelID",
topic: fmt.Sprintf(topicFmt, domainID, channelID),
domain: domainID,
channel: channelID,
domainID: domainID,
channelID: channelID,
response: fmt.Sprintf(topicFmt, domainID, channelID),
err: nil,
},
{
desc: "valid topic with domain route and channel ID",
topic: fmt.Sprintf(topicFmt, validRoute, channelID),
domain: validRoute,
channel: channelID,
domainID: domainID,
channelID: channelID,
response: fmt.Sprintf(topicFmt, domainID, channelID),
err: nil,
},
{
desc: "valid topic with domain ID and channel route",
topic: fmt.Sprintf(topicFmt, domainID, validRoute),
domain: domainID,
channel: validRoute,
domainID: domainID,
channelID: channelID,
response: fmt.Sprintf(topicFmt, domainID, channelID),
err: nil,
},
{
desc: "valid topic with domain route and channel route",
topic: fmt.Sprintf(topicFmt, validRoute, validRoute),
domain: validRoute,
channel: validRoute,
domainID: domainID,
channelID: channelID,
response: fmt.Sprintf(topicFmt, domainID, channelID),
err: nil,
},
{
desc: "invalid topic with invalid domain route and valid channel",
topic: fmt.Sprintf(topicFmt, invalidRoute, channelID),
domain: invalidRoute,
channel: channelID,
domainID: "",
channelID: "",
domainsErr: svcerr.ErrNotFound,
err: messaging.ErrFailedResolveDomain,
},
{
desc: "valid topic with valid topic with domainID and channelID and subtopic",
topic: fmt.Sprintf(topicFmt, domainID, channelID) + "/subtopic",
domain: domainID,
channel: channelID,
domainID: domainID,
channelID: channelID,
response: fmt.Sprintf(topicFmt, domainID, channelID) + "/subtopic",
err: nil,
},
{
desc: "invalid topic with empty domain",
topic: fmt.Sprintf(topicFmt, "", channelID),
domain: "",
channel: channelID,
domainID: "",
channelID: "",
err: messaging.ErrMalformedTopic,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
domainsCall := domains.On("RetrieveByRoute", mock.Anything, &grpcCommonV1.RetrieveByRouteReq{Route: tc.domain}).Return(&grpcCommonV1.RetrieveEntityRes{
Entity: &grpcCommonV1.EntityBasic{
Id: tc.domainID,
},
}, tc.domainsErr)
channelsCall := channels.On("RetrieveByRoute", mock.Anything, &grpcCommonV1.RetrieveByRouteReq{Route: tc.channel, DomainId: tc.domainID}).Return(&grpcCommonV1.RetrieveEntityRes{
Entity: &grpcCommonV1.EntityBasic{
Id: tc.channelID,
},
}, tc.channelsErr)
rtopic, err := resolver.ResolveTopic(context.Background(), tc.topic)
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
if err == nil {
assert.Equal(t, tc.response, rtopic, "expected topic %s, got %s", tc.response, rtopic)
}
domainsCall.Unset()
channelsCall.Unset()
})
}
}
+86 -28
View File
@@ -6,27 +6,19 @@ package messaging
import (
"fmt"
"net/url"
"regexp"
"strings"
"github.com/absmach/supermq/pkg/errors"
)
const (
MsgTopicPrefix = "m"
ChannelTopicPrefix = "c"
numGroups = 4 // entire expression + domain group + channel group + subtopic group
domainGroup = 1 // domain group is first in msg topic regexp
channelGroup = 2 // channel group is second in msg topic regexp
subtopicGroup = 3 // subtopic group is third in msg topic regexp
MsgTopicPrefix = 'm'
ChannelTopicPrefix = 'c'
)
var (
ErrMalformedTopic = errors.New("malformed topic")
ErrMalformedSubtopic = errors.New("malformed subtopic")
// Regex to group topic in format m.<domain_id>.c.<channel_id>.<sub_topic> `^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`.
TopicRegExp = regexp.MustCompile(`^\/?` + MsgTopicPrefix + `\/([\w\-]+)\/` + ChannelTopicPrefix + `\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
mqWildcards = "+#"
wildcards = "*>"
subtopicInvalidChars = " #+"
@@ -35,15 +27,10 @@ var (
)
func ParsePublishTopic(topic string) (domainID, chanID, subtopic string, err error) {
msgParts := TopicRegExp.FindStringSubmatch(topic)
if len(msgParts) < numGroups {
return "", "", "", ErrMalformedTopic
domainID, chanID, subtopic, err = ParseTopic(topic)
if err != nil {
return "", "", "", err
}
domainID = msgParts[domainGroup]
chanID = msgParts[channelGroup]
subtopic = msgParts[subtopicGroup]
subtopic, err = ParsePublishSubtopic(subtopic)
if err != nil {
return "", "", "", errors.Wrap(ErrMalformedTopic, err)
@@ -74,14 +61,10 @@ func ParsePublishSubtopic(subtopic string) (parseSubTopic string, err error) {
}
func ParseSubscribeTopic(topic string) (domainID string, chanID string, subtopic string, err error) {
msgParts := TopicRegExp.FindStringSubmatch(topic)
if len(msgParts) < numGroups {
return "", "", "", ErrMalformedTopic
domainID, chanID, subtopic, err = ParseTopic(topic)
if err != nil {
return "", "", "", err
}
domainID = msgParts[domainGroup]
chanID = msgParts[channelGroup]
subtopic = msgParts[subtopicGroup]
subtopic, err = ParseSubscribeSubtopic(subtopic)
if err != nil {
return "", "", "", errors.Wrap(ErrMalformedTopic, err)
@@ -132,11 +115,11 @@ func formatSubtopic(subtopic string) (string, error) {
}
func EncodeTopic(domainID string, channelID string, subtopic string) string {
return fmt.Sprintf("%s.%s", MsgTopicPrefix, EncodeTopicSuffix(domainID, channelID, subtopic))
return fmt.Sprintf("%s.%s", string(MsgTopicPrefix), EncodeTopicSuffix(domainID, channelID, subtopic))
}
func EncodeTopicSuffix(domainID string, channelID string, subtopic string) string {
subject := fmt.Sprintf("%s.%s.%s", domainID, ChannelTopicPrefix, channelID)
subject := fmt.Sprintf("%s.%s.%s", domainID, string(ChannelTopicPrefix), channelID)
if subtopic != "" {
subject = fmt.Sprintf("%s.%s", subject, subtopic)
}
@@ -148,9 +131,84 @@ func EncodeMessageTopic(m *Message) string {
}
func EncodeMessageMQTTTopic(m *Message) string {
topic := fmt.Sprintf("%s/%s/%s/%s", MsgTopicPrefix, m.GetDomain(), ChannelTopicPrefix, m.GetChannel())
topic := fmt.Sprintf("%s/%s/%s/%s", string(MsgTopicPrefix), m.GetDomain(), string(ChannelTopicPrefix), m.GetChannel())
if m.GetSubtopic() != "" {
topic = topic + "/" + strings.ReplaceAll(m.GetSubtopic(), ".", "/")
}
return topic
}
func EncodeAdapterTopic(domain, channel, subtopic string) string {
topic := fmt.Sprintf("%s/%s/%s/%s", string(MsgTopicPrefix), domain, string(ChannelTopicPrefix), channel)
if subtopic != "" {
topic = topic + "/" + subtopic
}
return topic
}
// ParseTopic parses a messaging topic string and returns the domain ID, channel ID, and subtopic.
// This is an optimized version with no regex and minimal allocations.
func ParseTopic(topic string) (domainID, chanID, subtopic string, err error) {
// location of string "m"
start := 0
// Handle both formats: "/m/domain/c/channel/subtopic" and "m/domain/c/channel/subtopic".
// If topic start with m/ then start is 0 , If topic start with /m/ then start is 1.
n := len(topic)
if n > 0 && topic[0] == '/' {
start = 1
}
// length check - minimum: "m/<domain_id>/c/" = 5 characters if ignore <domain_id> and in this case start will be 0
// length check - minimum: "/m/<domain_id>/c/" = 6 characters if ignore <domain_id> and in this case start will be 1
if n < start+5 {
return "", "", "", ErrMalformedTopic
}
if topic[start] != MsgTopicPrefix || topic[start+1] != '/' {
return "", "", "", ErrMalformedTopic
}
pos := start + 2
// Find "/c/" to locate domain ID
cPos := -1
for i := pos; i <= n-3; i++ {
if topic[i] == '/' && topic[i+1] == ChannelTopicPrefix && topic[i+2] == '/' {
cPos = i - pos
break
}
}
if cPos == -1 || cPos == 0 {
return "", "", "", ErrMalformedTopic
}
domainID = topic[pos : pos+cPos]
// skip "/c/"
pos = pos + cPos + 3
// Ensure channel exists
if pos >= n {
return "", "", "", ErrMalformedTopic
}
// Find '/' after channelID
nextSlash := -1
for i := pos; i < n; i++ {
if topic[i] == '/' {
nextSlash = i - pos
break
}
}
if nextSlash == -1 {
// No subtopic
chanID = topic[pos:]
} else {
chanID = topic[pos : pos+nextSlash]
subtopic = topic[pos+nextSlash+1:]
}
// Validate channelID
if len(chanID) == 0 {
return "", "", "", ErrMalformedTopic
}
return domainID, chanID, subtopic, nil
}
+70 -33
View File
@@ -6,6 +6,7 @@ package messaging_test
import (
"testing"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/messaging"
"github.com/stretchr/testify/assert"
)
@@ -16,7 +17,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID string
channelID string
subtopic string
expectErr bool
err error
}{
{
desc: "valid topic with subtopic /m/domain123/c/channel456/devices/temp",
@@ -24,6 +25,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "devices.temp",
err: nil,
},
{
desc: "valid topic with URL encoded subtopic /m/domain123/c/channel456/devices%2Ftemp%2Fdata",
@@ -46,13 +48,20 @@ var ParsePublisherTopicTestCases = []struct {
channelID: "channel456",
subtopic: "",
},
{
desc: "valid topic with trailing slash /m/domain123/c/channel456/devices/temp/",
topic: "/m/domain123/c/channel456/devices/temp/",
domainID: "domain123",
channelID: "channel456",
subtopic: "devices.temp",
},
{
desc: "invalid topic format (missing parts) /m/domain123/c/",
topic: "/m/domain123/c/",
domainID: "domain123",
channelID: "",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid topic format (missing domain) /m//c/channel123",
@@ -60,7 +69,15 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "",
channelID: "",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid topic format (missing channel) /m/domain123/c/",
topic: "/m/domain123/c//subtopic",
domainID: "domain123",
channelID: "",
subtopic: "",
err: messaging.ErrMalformedTopic,
},
{
desc: "topic with wildcards + and # /m/domain123/c/channel456/devices/+/temp/#",
@@ -68,7 +85,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid domain name m/domain*123/c/channel456/devices/+/temp/#",
@@ -76,7 +93,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "",
channelID: "channel456",
subtopic: "devices.*.temp.>",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a*b/topic",
@@ -84,7 +101,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a>b/topic",
@@ -92,7 +109,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a#b/topic",
@@ -100,7 +117,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a+b/topic",
@@ -108,7 +125,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a//b/topic",
@@ -116,7 +133,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid topic regex \"not-a-topic\"",
@@ -124,7 +141,12 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "",
channelID: "",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "extra segment before prefix /extra/m/domain/c/channel",
topic: "/extra/m/domain/c/channel",
err: messaging.ErrMalformedTopic,
},
}
@@ -132,10 +154,8 @@ func TestParsePublishTopic(t *testing.T) {
for _, tc := range ParsePublisherTopicTestCases {
t.Run(tc.desc, func(t *testing.T) {
domainID, channelID, subtopic, err := messaging.ParsePublishTopic(tc.topic)
if tc.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
if err == nil {
assert.Equal(t, tc.domainID, domainID)
assert.Equal(t, tc.channelID, channelID)
assert.Equal(t, tc.subtopic, subtopic)
@@ -147,7 +167,7 @@ func TestParsePublishTopic(t *testing.T) {
func BenchmarkParsePublisherTopic(b *testing.B) {
for _, tc := range ParsePublisherTopicTestCases {
b.Run(tc.desc, func(b *testing.B) {
for i := 0; i < b.N; i++ {
for b.Loop() {
_, _, _, _ = messaging.ParsePublishTopic(tc.topic)
}
})
@@ -160,7 +180,7 @@ var ParseSubscribeTestCases = []struct {
domainID string
channelID string
subtopic string
expectErr bool
err error
}{
{
desc: "valid topic with subtopic /m/domain123/c/channel456/devices/temp",
@@ -183,13 +203,20 @@ var ParseSubscribeTestCases = []struct {
channelID: "channel456",
subtopic: "",
},
{
desc: "valid topic with trailing slash /m/domain123/c/channel456/devices/temp/",
topic: "/m/domain123/c/channel456/devices/temp/",
domainID: "domain123",
channelID: "channel456",
subtopic: "devices.temp",
},
{
desc: "invalid topic format (missing channel) /m/domain123/c/",
topic: "/m/domain123/c/",
domainID: "domain123",
channelID: "",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid topic format (missing domain) /m//c/channel123",
@@ -197,15 +224,22 @@ var ParseSubscribeTestCases = []struct {
domainID: "",
channelID: "",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid topic format (missing channel) /m/domain123/c/",
topic: "/m/domain123/c//subtopic",
domainID: "domain123",
channelID: "",
subtopic: "",
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid domain name m/domain*123/c/channel456/devices/+/temp/#",
topic: "m/domain*123/c/channel456/devices/+/temp/#",
domainID: "",
domainID: "domain*123",
channelID: "channel456",
subtopic: "devices.*.temp.>",
expectErr: true,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a*b/topic",
@@ -213,7 +247,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a>b/topic",
@@ -221,7 +255,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a#b/topic",
@@ -229,7 +263,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a+b/topic",
@@ -237,7 +271,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a//b/topic",
@@ -245,7 +279,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a/ /b/topic",
@@ -253,7 +287,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "completely invalid topic \"invalid-topic\"",
@@ -261,7 +295,12 @@ var ParseSubscribeTestCases = []struct {
domainID: "",
channelID: "",
subtopic: "",
expectErr: true,
err: messaging.ErrMalformedTopic,
},
{
desc: "extra segment before prefix /extra/m/domain/c/channel",
topic: "/extra/m/domain/c/channel",
err: messaging.ErrMalformedTopic,
},
}
@@ -269,10 +308,8 @@ func TestParseSubscribeTopic(t *testing.T) {
for _, tc := range ParseSubscribeTestCases {
t.Run(tc.desc, func(t *testing.T) {
domainID, channelID, subtopic, err := messaging.ParseSubscribeTopic(tc.topic)
if tc.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
if err == nil {
assert.Equal(t, tc.domainID, domainID)
assert.Equal(t, tc.channelID, channelID)
assert.Equal(t, tc.subtopic, subtopic)
@@ -284,7 +321,7 @@ func TestParseSubscribeTopic(t *testing.T) {
func BenchmarkParseSubscribeTopic(b *testing.B) {
for _, tc := range ParseSubscribeTestCases {
b.Run(tc.desc, func(b *testing.B) {
for i := 0; i < b.N; i++ {
for b.Loop() {
_, _, _, _ = messaging.ParseSubscribeTopic(tc.topic)
}
})
+3 -1
View File
@@ -28,6 +28,7 @@ import (
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
pubsub "github.com/absmach/supermq/pkg/messaging/mocks"
sdk "github.com/absmach/supermq/pkg/sdk"
"github.com/stretchr/testify/assert"
@@ -46,7 +47,8 @@ func setupMessages() (*httptest.Server, *pubsub.PubSub) {
domainsGRPCClient = new(dmocks.DomainsServiceClient)
pub := new(pubsub.PubSub)
authn := new(authnmocks.Authentication)
handler := adapter.NewHandler(pub, authn, clientsGRPCClient, channelsGRPCClient, domainsGRPCClient, smqlog.NewMock())
resolver := messaging.NewTopicResolver(channelsGRPCClient, domainsGRPCClient)
handler := adapter.NewHandler(pub, authn, clientsGRPCClient, channelsGRPCClient, resolver, smqlog.NewMock())
mux := api.MakeHandler(smqlog.NewMock(), "")
target := httptest.NewServer(mux)
+4 -59
View File
@@ -9,9 +9,6 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
api "github.com/absmach/supermq/api/http"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -26,10 +23,6 @@ var (
ErrFailedSubscribe = errors.New("failed to unsubscribe from topic")
// ErrEmptyTopic indicate absence of clientKey in the request.
ErrEmptyTopic = errors.New("empty topic")
// errFailedResolveDomain indicates that the domain route could not be resolved.
errFailedResolveDomain = errors.New("failed to resolve domain route")
// errFailedResolveChannel indicates that the channel route could not be resolved.
errFailedResolveChannel = errors.New("failed to resolve channel route")
)
// Service specifies web socket service API.
@@ -48,34 +41,23 @@ var _ Service = (*adapterService)(nil)
type adapterService struct {
clients grpcClientsV1.ClientsServiceClient
channels grpcChannelsV1.ChannelsServiceClient
domains grpcDomainsV1.DomainsServiceClient
pubsub messaging.PubSub
}
// New instantiates the WS adapter implementation.
func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, pubsub messaging.PubSub) Service {
func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, pubsub messaging.PubSub) Service {
return &adapterService{
clients: clients,
channels: channels,
domains: domains,
pubsub: pubsub,
}
}
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey, domain, channel, subtopic string, c *Client) error {
if channel == "" || clientKey == "" || domain == "" {
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey, domainID, channelID, subtopic string, c *Client) error {
if channelID == "" || clientKey == "" || domainID == "" {
return svcerr.ErrAuthentication
}
domainID, err := svc.resolveDomain(ctx, domain)
if err != nil {
return errFailedResolveDomain
}
channelID, err := svc.resolveChannel(ctx, channel, domainID)
if err != nil {
return errFailedResolveChannel
}
clientID, err := svc.authorize(ctx, clientKey, domainID, channelID, connections.Subscribe)
if err != nil {
return svcerr.ErrAuthorization
@@ -97,15 +79,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey,
return nil
}
func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domain, channel, subtopic string) error {
domainID, err := svc.resolveDomain(ctx, domain)
if err != nil {
return errors.Wrap(errFailedResolveDomain, err)
}
channelID, err := svc.resolveChannel(ctx, channel, domainID)
if err != nil {
return errors.Wrap(errFailedResolveChannel, err)
}
func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID, channelID, subtopic string) error {
topic := messaging.EncodeTopic(domainID, channelID, subtopic)
if err := svc.pubsub.Unsubscribe(ctx, sessionID, topic); err != nil {
@@ -148,32 +122,3 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, c
return authnRes.GetId(), nil
}
func (svc *adapterService) resolveDomain(ctx context.Context, domain string) (string, error) {
if api.ValidateUUID(domain) == nil {
return domain, nil
}
d, err := svc.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: domain,
})
if err != nil {
return "", err
}
return d.Entity.Id, nil
}
func (svc *adapterService) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
if api.ValidateUUID(channel) == nil {
return channel, nil
}
c, err := svc.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: channel,
DomainId: domainID,
})
if err != nil {
return "", err
}
return c.Entity.Id, nil
}
+3 -10
View File
@@ -12,10 +12,8 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
chmocks "github.com/absmach/supermq/channels/mocks"
climocks "github.com/absmach/supermq/clients/mocks"
dmocks "github.com/absmach/supermq/domains/mocks"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
@@ -52,17 +50,16 @@ var (
sessionID = "sessionID"
)
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient, *dmocks.DomainsServiceClient) {
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient) {
pubsub := new(mocks.PubSub)
clients := new(climocks.ClientsServiceClient)
channels := new(chmocks.ChannelsServiceClient)
domains := new(dmocks.DomainsServiceClient)
return ws.New(clients, channels, domains, pubsub), pubsub, clients, channels, domains
return ws.New(clients, channels, pubsub), pubsub, clients, channels
}
func TestSubscribe(t *testing.T) {
svc, pubsub, clients, channels, domains := newService()
svc, pubsub, clients, channels := newService()
c := ws.NewClient(slog.Default(), nil, sessionID)
@@ -197,7 +194,6 @@ func TestSubscribe(t *testing.T) {
if strings.HasPrefix(tc.clientKey, "Client") {
authReq.ClientSecret = strings.TrimPrefix(tc.clientKey, "Client ")
}
domainsCall := domains.On("RetrieveByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil)
clientsCall := clients.On("Authenticate", mock.Anything, authReq).Return(tc.authNRes, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: policies.ClientType,
@@ -206,14 +202,11 @@ func TestSubscribe(t *testing.T) {
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
channelsCall1 := channels.On("RetrieveByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.chanID}}, nil)
repocall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
err := svc.Subscribe(context.Background(), sessionID, tc.clientKey, tc.domainID, tc.chanID, tc.subtopic, c)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repocall.Unset()
clientsCall.Unset()
channelsCall.Unset()
domainsCall.Unset()
channelsCall1.Unset()
}
}
+9 -8
View File
@@ -18,7 +18,6 @@ import (
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
chmocks "github.com/absmach/supermq/channels/mocks"
climocks "github.com/absmach/supermq/clients/mocks"
dmocks "github.com/absmach/supermq/domains/mocks"
@@ -26,6 +25,7 @@ import (
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnMocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/messaging/mocks"
"github.com/absmach/supermq/ws"
"github.com/absmach/supermq/ws/api"
@@ -47,13 +47,13 @@ var (
id = testsutil.GenerateUUID(&testing.T{})
)
func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (ws.Service, *mocks.PubSub) {
func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) (ws.Service, *mocks.PubSub) {
pubsub := new(mocks.PubSub)
return ws.New(clients, channels, domains, pubsub), pubsub
return ws.New(clients, channels, pubsub), pubsub
}
func newHTTPServer(svc ws.Service) *httptest.Server {
mux := api.MakeHandler(context.Background(), svc, smqlog.NewMock(), instanceID)
func newHTTPServer(svc ws.Service, resolver messaging.TopicResolver) *httptest.Server {
mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), instanceID)
return httptest.NewServer(mux)
}
@@ -116,10 +116,11 @@ func TestHandshake(t *testing.T) {
channels := new(chmocks.ChannelsServiceClient)
authn := new(authnMocks.Authentication)
domains := new(dmocks.DomainsServiceClient)
svc, pubsub := newService(clients, channels, domains)
target := newHTTPServer(svc)
resolver := messaging.NewTopicResolver(channels, domains)
svc, pubsub := newService(clients, channels)
target := newHTTPServer(svc, resolver)
defer target.Close()
handler := ws.NewHandler(pubsub, smqlog.NewMock(), authn, clients, channels, domains)
handler := ws.NewHandler(pubsub, smqlog.NewMock(), authn, clients, channels, resolver)
ts, err := newProxyHTPPServer(handler, target)
require.Nil(t, err)
defer ts.Close()
+13 -8
View File
@@ -27,9 +27,9 @@ func generateSessionID() (string, error) {
return hex.EncodeToString(b), nil
}
func handshake(ctx context.Context, svc ws.Service, logger *slog.Logger) http.HandlerFunc {
func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicResolver, logger *slog.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
req, err := decodeRequest(r, logger)
req, err := decodeRequest(r, resolver, logger)
if err != nil {
encodeError(w, err)
return
@@ -51,21 +51,21 @@ func handshake(ctx context.Context, svc ws.Service, logger *slog.Logger) http.Ha
client := ws.NewClient(logger, conn, sessionID)
client.SetCloseHandler(func(code int, text string) error {
return svc.Unsubscribe(ctx, sessionID, req.domain, req.channel, req.subtopic)
return svc.Unsubscribe(ctx, sessionID, req.domainID, req.channelID, req.subtopic)
})
go client.Start(ctx)
if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domain, req.channel, req.subtopic, client); err != nil {
if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.channelID, req.subtopic, client); err != nil {
conn.Close()
return
}
logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.channel))
logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.channelID))
}
}
func decodeRequest(r *http.Request, logger *slog.Logger) (connReq, error) {
func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *slog.Logger) (connReq, error) {
authKey := r.Header.Get("Authorization")
if authKey == "" {
authKeys := r.URL.Query()["authorization"]
@@ -79,10 +79,15 @@ func decodeRequest(r *http.Request, logger *slog.Logger) (connReq, error) {
domain := chi.URLParam(r, "domain")
channel := chi.URLParam(r, "channel")
domainID, channelID, err := resolver.Resolve(r.Context(), domain, channel)
if err != nil {
return connReq{}, err
}
req := connReq{
clientKey: authKey,
channel: channel,
domain: domain,
channelID: channelID,
domainID: domainID,
}
subTopic := chi.URLParam(r, "*")
+2 -2
View File
@@ -5,7 +5,7 @@ package api
type connReq struct {
clientKey string
channel string
domain string
channelID string
domainID string
subtopic string
}
+4 -3
View File
@@ -10,6 +10,7 @@ import (
"net/http"
"github.com/absmach/supermq"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/ws"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
@@ -36,12 +37,12 @@ var (
)
// MakeHandler returns http handler with handshake endpoint.
func MakeHandler(ctx context.Context, svc ws.Service, l *slog.Logger, instanceID string) http.Handler {
func MakeHandler(ctx context.Context, svc ws.Service, resolver messaging.TopicResolver, l *slog.Logger, instanceID string) http.Handler {
logger = l
mux := chi.NewRouter()
mux.Get("/m/{domain}/c/{channel}", handshake(ctx, svc, l))
mux.Get("/m/{domain}/c/{channel}/*", handshake(ctx, svc, l))
mux.Get("/m/{domain}/c/{channel}", handshake(ctx, svc, resolver, l))
mux.Get("/m/{domain}/c/{channel}/*", handshake(ctx, svc, resolver, l))
mux.Get("/health", supermq.Health(service, instanceID))
mux.Handle("/metrics", promhttp.Handler())
+12 -57
View File
@@ -15,9 +15,6 @@ import (
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
@@ -53,20 +50,20 @@ type handler struct {
pubsub messaging.PubSub
clients grpcClientsV1.ClientsServiceClient
channels grpcChannelsV1.ChannelsServiceClient
domains grpcDomainsV1.DomainsServiceClient
authn smqauthn.Authentication
logger *slog.Logger
resolver messaging.TopicResolver
}
// NewHandler creates new Handler entity.
func NewHandler(pubsub messaging.PubSub, logger *slog.Logger, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) session.Handler {
func NewHandler(pubsub messaging.PubSub, logger *slog.Logger, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, resolver messaging.TopicResolver) session.Handler {
return &handler{
logger: logger,
pubsub: pubsub,
authn: authn,
clients: clients,
channels: channels,
domains: domains,
resolver: resolver,
}
}
@@ -97,18 +94,14 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
domain, channel, _, err := messaging.ParsePublishTopic(*topic)
if err != nil {
return err
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
}
domainID, err := h.resolveDomain(ctx, domain)
domainID, channelID, err := h.resolver.Resolve(ctx, domain, channel)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveDomain, err))
}
chanID, err := h.resolveChannel(ctx, channel, domainID)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveChannel, err))
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
}
clientID, clientType, err := h.authAccess(ctx, token, domainID, chanID, connections.Publish)
clientID, clientType, err := h.authAccess(ctx, token, domainID, channelID, connections.Publish)
if err != nil {
return err
}
@@ -136,19 +129,14 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
if err != nil {
return err
}
domainID, err := h.resolveDomain(ctx, domain)
domainID, chanID, err := h.resolver.Resolve(ctx, domain, channel)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveDomain, err))
}
chanID, err := h.resolveChannel(ctx, channel, domainID)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveChannel, err))
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
}
if _, _, err := h.authAccess(ctx, string(s.Password), domainID, chanID, connections.Subscribe); err != nil {
return err
}
}
return nil
}
@@ -173,19 +161,15 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
if err != nil {
return errors.Wrap(errFailedPublish, err)
}
domainID, err := h.resolveDomain(ctx, domain)
domainID, channelID, err := h.resolver.Resolve(ctx, domain, channel)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveDomain, err))
}
chanID, err := h.resolveChannel(ctx, channel, domainID)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveChannel, err))
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
}
msg := messaging.Message{
Protocol: protocol,
Domain: domainID,
Channel: chanID,
Channel: channelID,
Subtopic: subtopic,
Payload: *payload,
Publisher: s.Username,
@@ -257,35 +241,6 @@ func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string
return clientID, clientType, nil
}
func (h *handler) resolveDomain(ctx context.Context, domain string) (string, error) {
if api.ValidateUUID(domain) == nil {
return domain, nil
}
d, err := h.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: domain,
})
if err != nil {
return "", err
}
return d.Entity.Id, nil
}
func (h *handler) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
if api.ValidateUUID(channel) == nil {
return channel, nil
}
c, err := h.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
Route: channel,
DomainId: domainID,
})
if err != nil {
return "", err
}
return c.Entity.Id, nil
}
// extractClientSecret returns value of the client secret. If there is no client key - an empty value is returned.
func extractClientSecret(token string) string {
if !strings.HasPrefix(token, apiutil.ClientPrefix) {