mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
SMQ-2956 - Add routes resolver to abstract topic resolution logic (#2960)
Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
+3
-1
@@ -20,6 +20,7 @@ import (
|
||||
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
|
||||
"github.com/absmach/supermq/pkg/grpcclient"
|
||||
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/pkg/messaging/brokers"
|
||||
brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing"
|
||||
msgevents "github.com/absmach/supermq/pkg/messaging/events"
|
||||
@@ -181,7 +182,8 @@ func main() {
|
||||
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(cfg.InstanceID), logger)
|
||||
|
||||
cs := coapserver.NewServer(ctx, cancel, svcName, coapServerConfig, httpapi.MakeCoAPHandler(svc, channelsClient, domainsClient, logger), logger)
|
||||
resolver := messaging.NewTopicResolver(channelsClient, domainsClient)
|
||||
cs := coapserver.NewServer(ctx, cancel, svcName, coapServerConfig, httpapi.MakeCoAPHandler(svc, channelsClient, resolver, logger), logger)
|
||||
|
||||
if cfg.SendTelemetry {
|
||||
chc := chclient.New(svcName, supermq.Version, logger, cancel)
|
||||
|
||||
+2
-1
@@ -221,7 +221,8 @@ func main() {
|
||||
}
|
||||
|
||||
func newService(pub messaging.Publisher, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, logger *slog.Logger, tracer trace.Tracer) session.Handler {
|
||||
svc := adapter.NewHandler(pub, authn, clients, channels, domains, logger)
|
||||
resolver := messaging.NewTopicResolver(channels, domains)
|
||||
svc := adapter.NewHandler(pub, authn, clients, channels, resolver, logger)
|
||||
svc = handler.NewTracing(tracer, svc)
|
||||
svc = handler.LoggingMiddleware(svc, logger)
|
||||
counter, latency := prometheus.MakeMetrics(svcName, "api")
|
||||
|
||||
+5
-59
@@ -23,10 +23,6 @@ import (
|
||||
"github.com/absmach/mgate/pkg/mqtt/websocket"
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
"github.com/absmach/supermq"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
"github.com/absmach/supermq/mqtt"
|
||||
"github.com/absmach/supermq/mqtt/events"
|
||||
@@ -57,11 +53,6 @@ const (
|
||||
wsPathPrefix = "/mqtt"
|
||||
)
|
||||
|
||||
var (
|
||||
errFailedResolveDomain = errors.New("failed to resolve domain route")
|
||||
errFailedResolveChannel = errors.New("failed to resolve channel route")
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"SMQ_MQTT_ADAPTER_LOG_LEVEL" envDefault:"info"`
|
||||
MQTTPort string `env:"SMQ_MQTT_ADAPTER_MQTT_PORT" envDefault:"1883"`
|
||||
@@ -247,8 +238,7 @@ func main() {
|
||||
}
|
||||
|
||||
beforeHandler := beforeHandler{
|
||||
domains: domainsClient,
|
||||
channels: channelsClient,
|
||||
resolver: messaging.NewTopicResolver(channelsClient, domainsClient),
|
||||
}
|
||||
|
||||
afterHandler := afterHandler{
|
||||
@@ -380,8 +370,7 @@ func (ah afterHandler) Intercept(ctx context.Context, pkt packets.ControlPacket,
|
||||
}
|
||||
|
||||
type beforeHandler struct {
|
||||
domains grpcDomainsV1.DomainsServiceClient
|
||||
channels grpcChannelsV1.ChannelsServiceClient
|
||||
resolver messaging.TopicResolver
|
||||
}
|
||||
|
||||
// This interceptor is used to replace domain and channel routes with relevant domain and channel IDs in the message topic.
|
||||
@@ -389,7 +378,7 @@ func (bh beforeHandler) Intercept(ctx context.Context, pkt packets.ControlPacket
|
||||
switch pt := pkt.(type) {
|
||||
case *packets.SubscribePacket:
|
||||
for i, topic := range pt.Topics {
|
||||
ft, err := bh.resolveTopic(ctx, topic)
|
||||
ft, err := bh.resolver.ResolveTopic(ctx, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -399,7 +388,7 @@ func (bh beforeHandler) Intercept(ctx context.Context, pkt packets.ControlPacket
|
||||
return pt, nil
|
||||
case *packets.UnsubscribePacket:
|
||||
for i, topic := range pt.Topics {
|
||||
ft, err := bh.resolveTopic(ctx, topic)
|
||||
ft, err := bh.resolver.ResolveTopic(ctx, topic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -407,7 +396,7 @@ func (bh beforeHandler) Intercept(ctx context.Context, pkt packets.ControlPacket
|
||||
}
|
||||
return pt, nil
|
||||
case *packets.PublishPacket:
|
||||
ft, err := bh.resolveTopic(ctx, pt.TopicName)
|
||||
ft, err := bh.resolver.ResolveTopic(ctx, pt.TopicName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -418,46 +407,3 @@ func (bh beforeHandler) Intercept(ctx context.Context, pkt packets.ControlPacket
|
||||
|
||||
return pkt, nil
|
||||
}
|
||||
|
||||
func (bh beforeHandler) resolveTopic(ctx context.Context, topic string) (string, error) {
|
||||
matches := messaging.TopicRegExp.FindStringSubmatch(topic)
|
||||
if len(matches) < 4 {
|
||||
return "", messaging.ErrMalformedTopic
|
||||
}
|
||||
|
||||
domainID, err := bh.resolveDomain(ctx, matches[1])
|
||||
if err != nil {
|
||||
return "", errors.Wrap(errFailedResolveDomain, err)
|
||||
}
|
||||
channelID, err := bh.resolveChannel(ctx, matches[2], domainID)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(errFailedResolveChannel, err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("m/%s/c/%s%s", domainID, channelID, matches[3]), nil
|
||||
}
|
||||
|
||||
func (bh beforeHandler) resolveDomain(ctx context.Context, domain string) (string, error) {
|
||||
if api.ValidateUUID(domain) == nil {
|
||||
return domain, nil
|
||||
}
|
||||
resp, err := bh.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{Route: domain})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.Entity.Id, nil
|
||||
}
|
||||
|
||||
func (bh beforeHandler) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
|
||||
if api.ValidateUUID(channel) == nil {
|
||||
return channel, nil
|
||||
}
|
||||
resp, err := bh.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: channel,
|
||||
DomainId: domainID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.Entity.Id, nil
|
||||
}
|
||||
|
||||
+6
-6
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/absmach/supermq"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
"github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
|
||||
@@ -194,10 +193,11 @@ func main() {
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
resolver := messaging.NewTopicResolver(channelsClient, domainsClient)
|
||||
|
||||
svc := newService(clientsClient, channelsClient, domainsClient, nps, logger, tracer)
|
||||
svc := newService(clientsClient, channelsClient, nps, logger, tracer)
|
||||
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, targetServerConfig, httpapi.MakeHandler(ctx, svc, logger, cfg.InstanceID), logger)
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, targetServerConfig, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger)
|
||||
|
||||
if cfg.SendTelemetry {
|
||||
chc := chclient.New(svcName, supermq.Version, logger, cancel)
|
||||
@@ -209,7 +209,7 @@ func main() {
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
handler := ws.NewHandler(nps, logger, authn, clientsClient, channelsClient, domainsClient)
|
||||
handler := ws.NewHandler(nps, logger, authn, clientsClient, channelsClient, resolver)
|
||||
return proxyWS(ctx, httpServerConfig, targetServerConfig, logger, handler)
|
||||
})
|
||||
|
||||
@@ -222,8 +222,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service {
|
||||
svc := ws.New(clientsClient, channels, domains, nps)
|
||||
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service {
|
||||
svc := ws.New(clientsClient, channels, nps)
|
||||
svc = tracing.New(tracer, svc)
|
||||
svc = httpapi.LoggingMiddleware(svc, logger)
|
||||
counter, latency := prometheus.MakeMetrics("ws_adapter", "api")
|
||||
|
||||
+36
-74
@@ -14,9 +14,6 @@ import (
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
"github.com/absmach/supermq/coap"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -36,17 +33,8 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
errBadOptions = errors.New("bad options")
|
||||
errMethodNotAllowed = errors.New("method not allowed")
|
||||
errFailedResolveDomain = errors.New("failed to resolve domain route")
|
||||
errFailedResolveChannel = errors.New("failed to resolve channel route")
|
||||
)
|
||||
|
||||
var (
|
||||
logger *slog.Logger
|
||||
service coap.Service
|
||||
channels grpcChannelsV1.ChannelsServiceClient
|
||||
domains grpcDomainsV1.DomainsServiceClient
|
||||
errBadOptions = errors.New("bad options")
|
||||
errMethodNotAllowed = errors.New("method not allowed")
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for API endpoints.
|
||||
@@ -58,39 +46,47 @@ func MakeHandler(instanceID string) http.Handler {
|
||||
return b
|
||||
}
|
||||
|
||||
// MakeCoAPHandler creates handler for CoAP messages.
|
||||
func MakeCoAPHandler(svc coap.Service, channelsClient grpcChannelsV1.ChannelsServiceClient, domainsClient grpcDomainsV1.DomainsServiceClient, l *slog.Logger) mux.HandlerFunc {
|
||||
logger = l
|
||||
service = svc
|
||||
channels = channelsClient
|
||||
domains = domainsClient
|
||||
|
||||
return handler
|
||||
type CoAPHandler struct {
|
||||
logger *slog.Logger
|
||||
service coap.Service
|
||||
channels grpcChannelsV1.ChannelsServiceClient
|
||||
resolver messaging.TopicResolver
|
||||
}
|
||||
|
||||
func sendResp(w mux.ResponseWriter, resp *pool.Message) {
|
||||
// MakeCoAPHandler creates handler for CoAP messages.
|
||||
func MakeCoAPHandler(svc coap.Service, channelsClient grpcChannelsV1.ChannelsServiceClient, resolver messaging.TopicResolver, l *slog.Logger) mux.HandlerFunc {
|
||||
h := &CoAPHandler{
|
||||
logger: l,
|
||||
service: svc,
|
||||
channels: channelsClient,
|
||||
resolver: resolver,
|
||||
}
|
||||
return h.handler
|
||||
}
|
||||
|
||||
func (h *CoAPHandler) sendResp(w mux.ResponseWriter, resp *pool.Message) {
|
||||
if err := w.Conn().WriteMessage(resp); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Can't set response: %s", err))
|
||||
h.logger.Warn(fmt.Sprintf("Can't set response: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
func handler(w mux.ResponseWriter, m *mux.Message) {
|
||||
func (h *CoAPHandler) handler(w mux.ResponseWriter, m *mux.Message) {
|
||||
resp := pool.NewMessage(w.Conn().Context())
|
||||
resp.SetToken(m.Token())
|
||||
for _, opt := range m.Options() {
|
||||
resp.AddOptionBytes(opt.ID, opt.Value)
|
||||
}
|
||||
defer sendResp(w, resp)
|
||||
defer h.sendResp(w, resp)
|
||||
|
||||
msg, err := decodeMessage(m)
|
||||
msg, err := h.decodeMessage(m)
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Error decoding message: %s", err))
|
||||
h.logger.Warn(fmt.Sprintf("Error decoding message: %s", err))
|
||||
resp.SetCode(codes.BadRequest)
|
||||
return
|
||||
}
|
||||
key, err := parseKey(m)
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Error parsing auth: %s", err))
|
||||
h.logger.Warn(fmt.Sprintf("Error parsing auth: %s", err))
|
||||
resp.SetCode(codes.Unauthorized)
|
||||
return
|
||||
}
|
||||
@@ -98,10 +94,10 @@ func handler(w mux.ResponseWriter, m *mux.Message) {
|
||||
switch m.Code() {
|
||||
case codes.GET:
|
||||
resp.SetCode(codes.Content)
|
||||
err = handleGet(m, w, msg, key)
|
||||
err = h.handleGet(m, w, msg, key)
|
||||
case codes.POST:
|
||||
resp.SetCode(codes.Created)
|
||||
err = service.Publish(m.Context(), key, msg)
|
||||
err = h.service.Publish(m.Context(), key, msg)
|
||||
default:
|
||||
err = errMethodNotAllowed
|
||||
}
|
||||
@@ -122,24 +118,24 @@ func handler(w mux.ResponseWriter, m *mux.Message) {
|
||||
}
|
||||
}
|
||||
|
||||
func handleGet(m *mux.Message, w mux.ResponseWriter, msg *messaging.Message, key string) error {
|
||||
func (h *CoAPHandler) handleGet(m *mux.Message, w mux.ResponseWriter, msg *messaging.Message, key string) error {
|
||||
var obs uint32
|
||||
obs, err := m.Options().Observe()
|
||||
if err != nil {
|
||||
logger.Warn(fmt.Sprintf("Error reading observe option: %s", err))
|
||||
h.logger.Warn(fmt.Sprintf("Error reading observe option: %s", err))
|
||||
return errBadOptions
|
||||
}
|
||||
if obs == startObserve {
|
||||
c := coap.NewClient(w.Conn(), m.Token(), logger)
|
||||
c := coap.NewClient(w.Conn(), m.Token(), h.logger)
|
||||
w.Conn().AddOnClose(func() {
|
||||
_ = service.DisconnectHandler(context.Background(), msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c.Token())
|
||||
_ = h.service.DisconnectHandler(context.Background(), msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c.Token())
|
||||
})
|
||||
return service.Subscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c)
|
||||
return h.service.Subscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c)
|
||||
}
|
||||
return service.Unsubscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), m.Token().String())
|
||||
return h.service.Unsubscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), m.Token().String())
|
||||
}
|
||||
|
||||
func decodeMessage(msg *mux.Message) (*messaging.Message, error) {
|
||||
func (h *CoAPHandler) decodeMessage(msg *mux.Message) (*messaging.Message, error) {
|
||||
if msg.Options() == nil {
|
||||
return &messaging.Message{}, errBadOptions
|
||||
}
|
||||
@@ -159,14 +155,9 @@ func decodeMessage(msg *mux.Message) (*messaging.Message, error) {
|
||||
return &messaging.Message{}, err
|
||||
}
|
||||
|
||||
domainID, err := resolveDomain(msg.Context(), domain)
|
||||
domainID, channelID, err := h.resolver.Resolve(msg.Context(), domain, channel)
|
||||
if err != nil {
|
||||
return &messaging.Message{}, errors.Wrap(errFailedResolveDomain, err)
|
||||
}
|
||||
|
||||
channelID, err := resolveChannel(msg.Context(), channel, domainID)
|
||||
if err != nil {
|
||||
return &messaging.Message{}, errors.Wrap(errFailedResolveChannel, err)
|
||||
return &messaging.Message{}, err
|
||||
}
|
||||
|
||||
ret := &messaging.Message{
|
||||
@@ -199,32 +190,3 @@ func parseKey(msg *mux.Message) (string, error) {
|
||||
}
|
||||
return vars[1], nil
|
||||
}
|
||||
|
||||
func resolveDomain(ctx context.Context, domain string) (string, error) {
|
||||
if api.ValidateUUID(domain) == nil {
|
||||
return domain, nil
|
||||
}
|
||||
d, err := domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: domain,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return d.Entity.Id, nil
|
||||
}
|
||||
|
||||
func resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
|
||||
if api.ValidateUUID(channel) == nil {
|
||||
return channel, nil
|
||||
}
|
||||
c, err := channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: channel,
|
||||
DomainId: domainID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return c.Entity.Id, nil
|
||||
}
|
||||
|
||||
@@ -51,7 +51,8 @@ var (
|
||||
|
||||
func newService(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (session.Handler, *pubsub.PubSub) {
|
||||
pub := new(pubsub.PubSub)
|
||||
return server.NewHandler(pub, authn, clients, channels, domains, smqlog.NewMock()), pub
|
||||
resolver := messaging.NewTopicResolver(channels, domains)
|
||||
return server.NewHandler(pub, authn, clients, channels, resolver, smqlog.NewMock()), pub
|
||||
}
|
||||
|
||||
func newTargetHTTPServer() *httptest.Server {
|
||||
|
||||
+5
-43
@@ -15,9 +15,6 @@ import (
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
@@ -52,8 +49,6 @@ var (
|
||||
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
|
||||
errMalformedTopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed topic"))
|
||||
errMissingTopicPub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to publish due to missing topic"))
|
||||
errFailedResolveDomain = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to resolve domain route"))
|
||||
errFailedResolveChannel = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to resolve channel route"))
|
||||
)
|
||||
|
||||
// Event implements events.Event interface.
|
||||
@@ -61,19 +56,19 @@ type handler struct {
|
||||
publisher messaging.Publisher
|
||||
clients grpcClientsV1.ClientsServiceClient
|
||||
channels grpcChannelsV1.ChannelsServiceClient
|
||||
domains grpcDomainsV1.DomainsServiceClient
|
||||
resolver messaging.TopicResolver
|
||||
authn smqauthn.Authentication
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewHandler creates new Handler entity.
|
||||
func NewHandler(publisher messaging.Publisher, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, logger *slog.Logger) session.Handler {
|
||||
func NewHandler(publisher messaging.Publisher, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, resolver messaging.TopicResolver, logger *slog.Logger) session.Handler {
|
||||
return &handler{
|
||||
publisher: publisher,
|
||||
authn: authn,
|
||||
clients: clients,
|
||||
channels: channels,
|
||||
domains: domains,
|
||||
resolver: resolver,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
@@ -130,13 +125,9 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
if err != nil {
|
||||
return errors.Wrap(errMalformedTopic, err)
|
||||
}
|
||||
domainID, err := h.resolveDomain(ctx, domain)
|
||||
domainID, channelID, err := h.resolver.Resolve(ctx, domain, channel)
|
||||
if err != nil {
|
||||
return errors.Wrap(errFailedResolveDomain, err)
|
||||
}
|
||||
channelID, err := h.resolveChannel(ctx, channel, domainID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errFailedResolveChannel, err)
|
||||
return errors.Wrap(errFailedPublish, err)
|
||||
}
|
||||
|
||||
var clientID, clientType string
|
||||
@@ -218,32 +209,3 @@ func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error {
|
||||
func (h *handler) Disconnect(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handler) resolveDomain(ctx context.Context, domain string) (string, error) {
|
||||
if api.ValidateUUID(domain) == nil {
|
||||
return domain, nil
|
||||
}
|
||||
d, err := h.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: domain,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return d.Entity.Id, nil
|
||||
}
|
||||
|
||||
func (h *handler) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
|
||||
if api.ValidateUUID(channel) == nil {
|
||||
return channel, nil
|
||||
}
|
||||
c, err := h.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: channel,
|
||||
DomainId: domainID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return c.Entity.Id, nil
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/pkg/messaging/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
@@ -74,10 +75,10 @@ func newHandler() session.Handler {
|
||||
authn = new(authnmocks.Authentication)
|
||||
clients = new(clmocks.ClientsServiceClient)
|
||||
channels = new(chmocks.ChannelsServiceClient)
|
||||
domains = new(dmocks.DomainsServiceClient)
|
||||
publisher = new(mocks.PubSub)
|
||||
resolver := messaging.NewTopicResolver(channels, domains)
|
||||
|
||||
return mhttp.NewHandler(publisher, authn, clients, channels, domains, logger)
|
||||
return mhttp.NewHandler(publisher, authn, clients, channels, resolver, logger)
|
||||
}
|
||||
|
||||
func TestAuthConnect(t *testing.T) {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
// SubjectAllMessages represents subject to subscribe for all the messages.
|
||||
const SubjectAllMessages = messaging.MsgTopicPrefix + ".>"
|
||||
const SubjectAllMessages = string(messaging.MsgTopicPrefix) + ".>"
|
||||
|
||||
func init() {
|
||||
log.Println("The binary was build using Nats as the message broker")
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
// SubjectAllMessages represents subject to subscribe for all the messages.
|
||||
const SubjectAllMessages = messaging.MsgTopicPrefix + ".#"
|
||||
const SubjectAllMessages = string(messaging.MsgTopicPrefix) + ".#"
|
||||
|
||||
func init() {
|
||||
log.Println("The binary was build using RabbitMQ as the message broker")
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package messaging
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyRouteID = errors.New("empty route or id")
|
||||
ErrFailedResolveDomain = errors.New("failed to resolve domain route")
|
||||
ErrFailedResolveChannel = errors.New("failed to resolve channel route")
|
||||
)
|
||||
|
||||
// TopicResolver contains definitions for resolving domain and channel IDs
|
||||
// from their respective routes from the message topic.
|
||||
type TopicResolver interface {
|
||||
Resolve(ctx context.Context, domain, channel string) (domainID string, channelID string, err error)
|
||||
ResolveTopic(ctx context.Context, topic string) (rtopic string, err error)
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
channels grpcChannelsV1.ChannelsServiceClient
|
||||
domains grpcDomainsV1.DomainsServiceClient
|
||||
}
|
||||
|
||||
// NewTopicResolver creates a new instance of TopicResolver.
|
||||
func NewTopicResolver(channelsClient grpcChannelsV1.ChannelsServiceClient, domainsClient grpcDomainsV1.DomainsServiceClient) TopicResolver {
|
||||
return &resolver{
|
||||
channels: channelsClient,
|
||||
domains: domainsClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) Resolve(ctx context.Context, domain, channel string) (string, string, error) {
|
||||
if domain == "" || channel == "" {
|
||||
return "", "", ErrEmptyRouteID
|
||||
}
|
||||
|
||||
domainID, err := r.resolveDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return "", "", errors.Wrap(ErrFailedResolveDomain, err)
|
||||
}
|
||||
channelID, err := r.resolveChannel(ctx, channel, domainID)
|
||||
if err != nil {
|
||||
return "", "", errors.Wrap(ErrFailedResolveChannel, err)
|
||||
}
|
||||
|
||||
return domainID, channelID, nil
|
||||
}
|
||||
|
||||
func (r *resolver) ResolveTopic(ctx context.Context, topic string) (string, error) {
|
||||
domain, channel, subtopic, err := ParseTopic(topic)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(ErrMalformedTopic, err)
|
||||
}
|
||||
|
||||
domainID, channelID, err := r.Resolve(ctx, domain, channel)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rtopic := EncodeAdapterTopic(domainID, channelID, subtopic)
|
||||
|
||||
return rtopic, nil
|
||||
}
|
||||
|
||||
func (r *resolver) resolveDomain(ctx context.Context, domain string) (string, error) {
|
||||
if validateUUID(domain) == nil {
|
||||
return domain, nil
|
||||
}
|
||||
d, err := r.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: domain,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return d.Entity.Id, nil
|
||||
}
|
||||
|
||||
func (r *resolver) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
|
||||
if validateUUID(channel) == nil {
|
||||
return channel, nil
|
||||
}
|
||||
c, err := r.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: channel,
|
||||
DomainId: domainID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return c.Entity.Id, nil
|
||||
}
|
||||
|
||||
func validateUUID(extID string) (err error) {
|
||||
id, err := uuid.FromString(extID)
|
||||
if id.String() != extID || err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package messaging_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
|
||||
chmocks "github.com/absmach/supermq/channels/mocks"
|
||||
dmocks "github.com/absmach/supermq/domains/mocks"
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
var (
|
||||
validRoute = "valid-route"
|
||||
invalidRoute = "invalid-route"
|
||||
channelID = testsutil.GenerateUUID(&testing.T{})
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
topicFmt = "m/%s/c/%s"
|
||||
)
|
||||
|
||||
func setupResolver() (messaging.TopicResolver, *dmocks.DomainsServiceClient, *chmocks.ChannelsServiceClient) {
|
||||
channels := new(chmocks.ChannelsServiceClient)
|
||||
domains := new(dmocks.DomainsServiceClient)
|
||||
resolver := messaging.NewTopicResolver(channels, domains)
|
||||
|
||||
return resolver, domains, channels
|
||||
}
|
||||
|
||||
func TestResolve(t *testing.T) {
|
||||
resolver, domains, channels := setupResolver()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
domain string
|
||||
channel string
|
||||
domainID string
|
||||
channelID string
|
||||
domainsErr error
|
||||
channelsErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid domainID and channelID",
|
||||
domain: domainID,
|
||||
channel: channelID,
|
||||
domainID: domainID,
|
||||
channelID: channelID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid domain route and channel ID",
|
||||
domain: validRoute,
|
||||
channel: channelID,
|
||||
domainID: domainID,
|
||||
channelID: channelID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid domain ID and channel route",
|
||||
domain: domainID,
|
||||
channel: validRoute,
|
||||
domainID: domainID,
|
||||
channelID: channelID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid domain route and channel route",
|
||||
domain: validRoute,
|
||||
channel: validRoute,
|
||||
domainID: domainID,
|
||||
channelID: channelID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "invalid domain route and valid channel",
|
||||
domain: invalidRoute,
|
||||
channel: channelID,
|
||||
domainID: "",
|
||||
channelID: "",
|
||||
domainsErr: svcerr.ErrNotFound,
|
||||
err: messaging.ErrFailedResolveDomain,
|
||||
},
|
||||
{
|
||||
desc: "valid domain and invalid channel",
|
||||
domain: domainID,
|
||||
channel: invalidRoute,
|
||||
domainID: domainID,
|
||||
channelID: "",
|
||||
channelsErr: svcerr.ErrNotFound,
|
||||
err: messaging.ErrFailedResolveChannel,
|
||||
},
|
||||
{
|
||||
desc: "empty domain",
|
||||
domain: "",
|
||||
channel: channelID,
|
||||
domainID: "",
|
||||
channelID: "",
|
||||
err: messaging.ErrEmptyRouteID,
|
||||
},
|
||||
{
|
||||
desc: "empty channel",
|
||||
domain: domainID,
|
||||
channel: "",
|
||||
domainID: domainID,
|
||||
channelID: "",
|
||||
err: messaging.ErrEmptyRouteID,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
domainsCall := domains.On("RetrieveByRoute", mock.Anything, &grpcCommonV1.RetrieveByRouteReq{Route: tc.domain}).Return(&grpcCommonV1.RetrieveEntityRes{
|
||||
Entity: &grpcCommonV1.EntityBasic{
|
||||
Id: tc.domainID,
|
||||
},
|
||||
}, tc.domainsErr)
|
||||
channelsCall := channels.On("RetrieveByRoute", mock.Anything, &grpcCommonV1.RetrieveByRouteReq{Route: tc.channel, DomainId: tc.domainID}).Return(&grpcCommonV1.RetrieveEntityRes{
|
||||
Entity: &grpcCommonV1.EntityBasic{
|
||||
Id: tc.channelID,
|
||||
},
|
||||
}, tc.channelsErr)
|
||||
domainID, channelID, err := resolver.Resolve(context.Background(), tc.domain, tc.channel)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.domainID, domainID, "expected domain ID %s, got %s", tc.domainID, domainID)
|
||||
assert.Equal(t, tc.channelID, channelID, "expected channel ID %s, got %s", tc.channelID, channelID)
|
||||
}
|
||||
domainsCall.Unset()
|
||||
channelsCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveTopic(t *testing.T) {
|
||||
resolver, domains, channels := setupResolver()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic string
|
||||
domain string
|
||||
channel string
|
||||
domainID string
|
||||
channelID string
|
||||
domainsErr error
|
||||
channelsErr error
|
||||
response string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid topic with domainID and channelID",
|
||||
topic: fmt.Sprintf(topicFmt, domainID, channelID),
|
||||
domain: domainID,
|
||||
channel: channelID,
|
||||
domainID: domainID,
|
||||
channelID: channelID,
|
||||
response: fmt.Sprintf(topicFmt, domainID, channelID),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid topic with domain route and channel ID",
|
||||
topic: fmt.Sprintf(topicFmt, validRoute, channelID),
|
||||
domain: validRoute,
|
||||
channel: channelID,
|
||||
domainID: domainID,
|
||||
channelID: channelID,
|
||||
response: fmt.Sprintf(topicFmt, domainID, channelID),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid topic with domain ID and channel route",
|
||||
topic: fmt.Sprintf(topicFmt, domainID, validRoute),
|
||||
domain: domainID,
|
||||
channel: validRoute,
|
||||
domainID: domainID,
|
||||
channelID: channelID,
|
||||
response: fmt.Sprintf(topicFmt, domainID, channelID),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid topic with domain route and channel route",
|
||||
topic: fmt.Sprintf(topicFmt, validRoute, validRoute),
|
||||
domain: validRoute,
|
||||
channel: validRoute,
|
||||
domainID: domainID,
|
||||
channelID: channelID,
|
||||
response: fmt.Sprintf(topicFmt, domainID, channelID),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "invalid topic with invalid domain route and valid channel",
|
||||
topic: fmt.Sprintf(topicFmt, invalidRoute, channelID),
|
||||
domain: invalidRoute,
|
||||
channel: channelID,
|
||||
domainID: "",
|
||||
channelID: "",
|
||||
domainsErr: svcerr.ErrNotFound,
|
||||
err: messaging.ErrFailedResolveDomain,
|
||||
},
|
||||
{
|
||||
desc: "valid topic with valid topic with domainID and channelID and subtopic",
|
||||
topic: fmt.Sprintf(topicFmt, domainID, channelID) + "/subtopic",
|
||||
domain: domainID,
|
||||
channel: channelID,
|
||||
domainID: domainID,
|
||||
channelID: channelID,
|
||||
response: fmt.Sprintf(topicFmt, domainID, channelID) + "/subtopic",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "invalid topic with empty domain",
|
||||
topic: fmt.Sprintf(topicFmt, "", channelID),
|
||||
domain: "",
|
||||
channel: channelID,
|
||||
domainID: "",
|
||||
channelID: "",
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
domainsCall := domains.On("RetrieveByRoute", mock.Anything, &grpcCommonV1.RetrieveByRouteReq{Route: tc.domain}).Return(&grpcCommonV1.RetrieveEntityRes{
|
||||
Entity: &grpcCommonV1.EntityBasic{
|
||||
Id: tc.domainID,
|
||||
},
|
||||
}, tc.domainsErr)
|
||||
channelsCall := channels.On("RetrieveByRoute", mock.Anything, &grpcCommonV1.RetrieveByRouteReq{Route: tc.channel, DomainId: tc.domainID}).Return(&grpcCommonV1.RetrieveEntityRes{
|
||||
Entity: &grpcCommonV1.EntityBasic{
|
||||
Id: tc.channelID,
|
||||
},
|
||||
}, tc.channelsErr)
|
||||
rtopic, err := resolver.ResolveTopic(context.Background(), tc.topic)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.response, rtopic, "expected topic %s, got %s", tc.response, rtopic)
|
||||
}
|
||||
domainsCall.Unset()
|
||||
channelsCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
+86
-28
@@ -6,27 +6,19 @@ package messaging
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
MsgTopicPrefix = "m"
|
||||
ChannelTopicPrefix = "c"
|
||||
|
||||
numGroups = 4 // entire expression + domain group + channel group + subtopic group
|
||||
domainGroup = 1 // domain group is first in msg topic regexp
|
||||
channelGroup = 2 // channel group is second in msg topic regexp
|
||||
subtopicGroup = 3 // subtopic group is third in msg topic regexp
|
||||
MsgTopicPrefix = 'm'
|
||||
ChannelTopicPrefix = 'c'
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMalformedTopic = errors.New("malformed topic")
|
||||
ErrMalformedSubtopic = errors.New("malformed subtopic")
|
||||
// Regex to group topic in format m.<domain_id>.c.<channel_id>.<sub_topic> `^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`.
|
||||
TopicRegExp = regexp.MustCompile(`^\/?` + MsgTopicPrefix + `\/([\w\-]+)\/` + ChannelTopicPrefix + `\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
|
||||
mqWildcards = "+#"
|
||||
wildcards = "*>"
|
||||
subtopicInvalidChars = " #+"
|
||||
@@ -35,15 +27,10 @@ var (
|
||||
)
|
||||
|
||||
func ParsePublishTopic(topic string) (domainID, chanID, subtopic string, err error) {
|
||||
msgParts := TopicRegExp.FindStringSubmatch(topic)
|
||||
if len(msgParts) < numGroups {
|
||||
return "", "", "", ErrMalformedTopic
|
||||
domainID, chanID, subtopic, err = ParseTopic(topic)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
domainID = msgParts[domainGroup]
|
||||
chanID = msgParts[channelGroup]
|
||||
subtopic = msgParts[subtopicGroup]
|
||||
|
||||
subtopic, err = ParsePublishSubtopic(subtopic)
|
||||
if err != nil {
|
||||
return "", "", "", errors.Wrap(ErrMalformedTopic, err)
|
||||
@@ -74,14 +61,10 @@ func ParsePublishSubtopic(subtopic string) (parseSubTopic string, err error) {
|
||||
}
|
||||
|
||||
func ParseSubscribeTopic(topic string) (domainID string, chanID string, subtopic string, err error) {
|
||||
msgParts := TopicRegExp.FindStringSubmatch(topic)
|
||||
if len(msgParts) < numGroups {
|
||||
return "", "", "", ErrMalformedTopic
|
||||
domainID, chanID, subtopic, err = ParseTopic(topic)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
domainID = msgParts[domainGroup]
|
||||
chanID = msgParts[channelGroup]
|
||||
subtopic = msgParts[subtopicGroup]
|
||||
subtopic, err = ParseSubscribeSubtopic(subtopic)
|
||||
if err != nil {
|
||||
return "", "", "", errors.Wrap(ErrMalformedTopic, err)
|
||||
@@ -132,11 +115,11 @@ func formatSubtopic(subtopic string) (string, error) {
|
||||
}
|
||||
|
||||
func EncodeTopic(domainID string, channelID string, subtopic string) string {
|
||||
return fmt.Sprintf("%s.%s", MsgTopicPrefix, EncodeTopicSuffix(domainID, channelID, subtopic))
|
||||
return fmt.Sprintf("%s.%s", string(MsgTopicPrefix), EncodeTopicSuffix(domainID, channelID, subtopic))
|
||||
}
|
||||
|
||||
func EncodeTopicSuffix(domainID string, channelID string, subtopic string) string {
|
||||
subject := fmt.Sprintf("%s.%s.%s", domainID, ChannelTopicPrefix, channelID)
|
||||
subject := fmt.Sprintf("%s.%s.%s", domainID, string(ChannelTopicPrefix), channelID)
|
||||
if subtopic != "" {
|
||||
subject = fmt.Sprintf("%s.%s", subject, subtopic)
|
||||
}
|
||||
@@ -148,9 +131,84 @@ func EncodeMessageTopic(m *Message) string {
|
||||
}
|
||||
|
||||
func EncodeMessageMQTTTopic(m *Message) string {
|
||||
topic := fmt.Sprintf("%s/%s/%s/%s", MsgTopicPrefix, m.GetDomain(), ChannelTopicPrefix, m.GetChannel())
|
||||
topic := fmt.Sprintf("%s/%s/%s/%s", string(MsgTopicPrefix), m.GetDomain(), string(ChannelTopicPrefix), m.GetChannel())
|
||||
if m.GetSubtopic() != "" {
|
||||
topic = topic + "/" + strings.ReplaceAll(m.GetSubtopic(), ".", "/")
|
||||
}
|
||||
return topic
|
||||
}
|
||||
|
||||
func EncodeAdapterTopic(domain, channel, subtopic string) string {
|
||||
topic := fmt.Sprintf("%s/%s/%s/%s", string(MsgTopicPrefix), domain, string(ChannelTopicPrefix), channel)
|
||||
if subtopic != "" {
|
||||
topic = topic + "/" + subtopic
|
||||
}
|
||||
return topic
|
||||
}
|
||||
|
||||
// ParseTopic parses a messaging topic string and returns the domain ID, channel ID, and subtopic.
|
||||
// This is an optimized version with no regex and minimal allocations.
|
||||
func ParseTopic(topic string) (domainID, chanID, subtopic string, err error) {
|
||||
// location of string "m"
|
||||
start := 0
|
||||
// Handle both formats: "/m/domain/c/channel/subtopic" and "m/domain/c/channel/subtopic".
|
||||
// If topic start with m/ then start is 0 , If topic start with /m/ then start is 1.
|
||||
n := len(topic)
|
||||
if n > 0 && topic[0] == '/' {
|
||||
start = 1
|
||||
}
|
||||
|
||||
// length check - minimum: "m/<domain_id>/c/" = 5 characters if ignore <domain_id> and in this case start will be 0
|
||||
// length check - minimum: "/m/<domain_id>/c/" = 6 characters if ignore <domain_id> and in this case start will be 1
|
||||
if n < start+5 {
|
||||
return "", "", "", ErrMalformedTopic
|
||||
}
|
||||
if topic[start] != MsgTopicPrefix || topic[start+1] != '/' {
|
||||
return "", "", "", ErrMalformedTopic
|
||||
}
|
||||
pos := start + 2
|
||||
|
||||
// Find "/c/" to locate domain ID
|
||||
cPos := -1
|
||||
for i := pos; i <= n-3; i++ {
|
||||
if topic[i] == '/' && topic[i+1] == ChannelTopicPrefix && topic[i+2] == '/' {
|
||||
cPos = i - pos
|
||||
break
|
||||
}
|
||||
}
|
||||
if cPos == -1 || cPos == 0 {
|
||||
return "", "", "", ErrMalformedTopic
|
||||
}
|
||||
domainID = topic[pos : pos+cPos]
|
||||
// skip "/c/"
|
||||
pos = pos + cPos + 3
|
||||
|
||||
// Ensure channel exists
|
||||
if pos >= n {
|
||||
return "", "", "", ErrMalformedTopic
|
||||
}
|
||||
|
||||
// Find '/' after channelID
|
||||
nextSlash := -1
|
||||
for i := pos; i < n; i++ {
|
||||
if topic[i] == '/' {
|
||||
nextSlash = i - pos
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if nextSlash == -1 {
|
||||
// No subtopic
|
||||
chanID = topic[pos:]
|
||||
} else {
|
||||
chanID = topic[pos : pos+nextSlash]
|
||||
subtopic = topic[pos+nextSlash+1:]
|
||||
}
|
||||
|
||||
// Validate channelID
|
||||
if len(chanID) == 0 {
|
||||
return "", "", "", ErrMalformedTopic
|
||||
}
|
||||
|
||||
return domainID, chanID, subtopic, nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package messaging_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -16,7 +17,7 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID string
|
||||
channelID string
|
||||
subtopic string
|
||||
expectErr bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid topic with subtopic /m/domain123/c/channel456/devices/temp",
|
||||
@@ -24,6 +25,7 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "devices.temp",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid topic with URL encoded subtopic /m/domain123/c/channel456/devices%2Ftemp%2Fdata",
|
||||
@@ -46,13 +48,20 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
},
|
||||
{
|
||||
desc: "valid topic with trailing slash /m/domain123/c/channel456/devices/temp/",
|
||||
topic: "/m/domain123/c/channel456/devices/temp/",
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "devices.temp",
|
||||
},
|
||||
{
|
||||
desc: "invalid topic format (missing parts) /m/domain123/c/",
|
||||
topic: "/m/domain123/c/",
|
||||
domainID: "domain123",
|
||||
channelID: "",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid topic format (missing domain) /m//c/channel123",
|
||||
@@ -60,7 +69,15 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "",
|
||||
channelID: "",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid topic format (missing channel) /m/domain123/c/",
|
||||
topic: "/m/domain123/c//subtopic",
|
||||
domainID: "domain123",
|
||||
channelID: "",
|
||||
subtopic: "",
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "topic with wildcards + and # /m/domain123/c/channel456/devices/+/temp/#",
|
||||
@@ -68,7 +85,7 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid domain name m/domain*123/c/channel456/devices/+/temp/#",
|
||||
@@ -76,7 +93,7 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "",
|
||||
channelID: "channel456",
|
||||
subtopic: "devices.*.temp.>",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a*b/topic",
|
||||
@@ -84,7 +101,7 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a>b/topic",
|
||||
@@ -92,7 +109,7 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a#b/topic",
|
||||
@@ -100,7 +117,7 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a+b/topic",
|
||||
@@ -108,7 +125,7 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a//b/topic",
|
||||
@@ -116,7 +133,7 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid topic regex \"not-a-topic\"",
|
||||
@@ -124,7 +141,12 @@ var ParsePublisherTopicTestCases = []struct {
|
||||
domainID: "",
|
||||
channelID: "",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "extra segment before prefix /extra/m/domain/c/channel",
|
||||
topic: "/extra/m/domain/c/channel",
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -132,10 +154,8 @@ func TestParsePublishTopic(t *testing.T) {
|
||||
for _, tc := range ParsePublisherTopicTestCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
domainID, channelID, subtopic, err := messaging.ParsePublishTopic(tc.topic)
|
||||
if tc.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.domainID, domainID)
|
||||
assert.Equal(t, tc.channelID, channelID)
|
||||
assert.Equal(t, tc.subtopic, subtopic)
|
||||
@@ -147,7 +167,7 @@ func TestParsePublishTopic(t *testing.T) {
|
||||
func BenchmarkParsePublisherTopic(b *testing.B) {
|
||||
for _, tc := range ParsePublisherTopicTestCases {
|
||||
b.Run(tc.desc, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for b.Loop() {
|
||||
_, _, _, _ = messaging.ParsePublishTopic(tc.topic)
|
||||
}
|
||||
})
|
||||
@@ -160,7 +180,7 @@ var ParseSubscribeTestCases = []struct {
|
||||
domainID string
|
||||
channelID string
|
||||
subtopic string
|
||||
expectErr bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid topic with subtopic /m/domain123/c/channel456/devices/temp",
|
||||
@@ -183,13 +203,20 @@ var ParseSubscribeTestCases = []struct {
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
},
|
||||
{
|
||||
desc: "valid topic with trailing slash /m/domain123/c/channel456/devices/temp/",
|
||||
topic: "/m/domain123/c/channel456/devices/temp/",
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "devices.temp",
|
||||
},
|
||||
{
|
||||
desc: "invalid topic format (missing channel) /m/domain123/c/",
|
||||
topic: "/m/domain123/c/",
|
||||
domainID: "domain123",
|
||||
channelID: "",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid topic format (missing domain) /m//c/channel123",
|
||||
@@ -197,15 +224,22 @@ var ParseSubscribeTestCases = []struct {
|
||||
domainID: "",
|
||||
channelID: "",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid topic format (missing channel) /m/domain123/c/",
|
||||
topic: "/m/domain123/c//subtopic",
|
||||
domainID: "domain123",
|
||||
channelID: "",
|
||||
subtopic: "",
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid domain name m/domain*123/c/channel456/devices/+/temp/#",
|
||||
topic: "m/domain*123/c/channel456/devices/+/temp/#",
|
||||
domainID: "",
|
||||
domainID: "domain*123",
|
||||
channelID: "channel456",
|
||||
subtopic: "devices.*.temp.>",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a*b/topic",
|
||||
@@ -213,7 +247,7 @@ var ParseSubscribeTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a>b/topic",
|
||||
@@ -221,7 +255,7 @@ var ParseSubscribeTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a#b/topic",
|
||||
@@ -229,7 +263,7 @@ var ParseSubscribeTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a+b/topic",
|
||||
@@ -237,7 +271,7 @@ var ParseSubscribeTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a//b/topic",
|
||||
@@ -245,7 +279,7 @@ var ParseSubscribeTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "invalid subtopic /m/domain123/c/channel456/sub/a/ /b/topic",
|
||||
@@ -253,7 +287,7 @@ var ParseSubscribeTestCases = []struct {
|
||||
domainID: "domain123",
|
||||
channelID: "channel456",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "completely invalid topic \"invalid-topic\"",
|
||||
@@ -261,7 +295,12 @@ var ParseSubscribeTestCases = []struct {
|
||||
domainID: "",
|
||||
channelID: "",
|
||||
subtopic: "",
|
||||
expectErr: true,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "extra segment before prefix /extra/m/domain/c/channel",
|
||||
topic: "/extra/m/domain/c/channel",
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -269,10 +308,8 @@ func TestParseSubscribeTopic(t *testing.T) {
|
||||
for _, tc := range ParseSubscribeTestCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
domainID, channelID, subtopic, err := messaging.ParseSubscribeTopic(tc.topic)
|
||||
if tc.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.domainID, domainID)
|
||||
assert.Equal(t, tc.channelID, channelID)
|
||||
assert.Equal(t, tc.subtopic, subtopic)
|
||||
@@ -284,7 +321,7 @@ func TestParseSubscribeTopic(t *testing.T) {
|
||||
func BenchmarkParseSubscribeTopic(b *testing.B) {
|
||||
for _, tc := range ParseSubscribeTestCases {
|
||||
b.Run(tc.desc, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for b.Loop() {
|
||||
_, _, _, _ = messaging.ParseSubscribeTopic(tc.topic)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
pubsub "github.com/absmach/supermq/pkg/messaging/mocks"
|
||||
sdk "github.com/absmach/supermq/pkg/sdk"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -46,7 +47,8 @@ func setupMessages() (*httptest.Server, *pubsub.PubSub) {
|
||||
domainsGRPCClient = new(dmocks.DomainsServiceClient)
|
||||
pub := new(pubsub.PubSub)
|
||||
authn := new(authnmocks.Authentication)
|
||||
handler := adapter.NewHandler(pub, authn, clientsGRPCClient, channelsGRPCClient, domainsGRPCClient, smqlog.NewMock())
|
||||
resolver := messaging.NewTopicResolver(channelsGRPCClient, domainsGRPCClient)
|
||||
handler := adapter.NewHandler(pub, authn, clientsGRPCClient, channelsGRPCClient, resolver, smqlog.NewMock())
|
||||
|
||||
mux := api.MakeHandler(smqlog.NewMock(), "")
|
||||
target := httptest.NewServer(mux)
|
||||
|
||||
+4
-59
@@ -9,9 +9,6 @@ import (
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -26,10 +23,6 @@ var (
|
||||
ErrFailedSubscribe = errors.New("failed to unsubscribe from topic")
|
||||
// ErrEmptyTopic indicate absence of clientKey in the request.
|
||||
ErrEmptyTopic = errors.New("empty topic")
|
||||
// errFailedResolveDomain indicates that the domain route could not be resolved.
|
||||
errFailedResolveDomain = errors.New("failed to resolve domain route")
|
||||
// errFailedResolveChannel indicates that the channel route could not be resolved.
|
||||
errFailedResolveChannel = errors.New("failed to resolve channel route")
|
||||
)
|
||||
|
||||
// Service specifies web socket service API.
|
||||
@@ -48,34 +41,23 @@ var _ Service = (*adapterService)(nil)
|
||||
type adapterService struct {
|
||||
clients grpcClientsV1.ClientsServiceClient
|
||||
channels grpcChannelsV1.ChannelsServiceClient
|
||||
domains grpcDomainsV1.DomainsServiceClient
|
||||
pubsub messaging.PubSub
|
||||
}
|
||||
|
||||
// New instantiates the WS adapter implementation.
|
||||
func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, pubsub messaging.PubSub) Service {
|
||||
func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, pubsub messaging.PubSub) Service {
|
||||
return &adapterService{
|
||||
clients: clients,
|
||||
channels: channels,
|
||||
domains: domains,
|
||||
pubsub: pubsub,
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey, domain, channel, subtopic string, c *Client) error {
|
||||
if channel == "" || clientKey == "" || domain == "" {
|
||||
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey, domainID, channelID, subtopic string, c *Client) error {
|
||||
if channelID == "" || clientKey == "" || domainID == "" {
|
||||
return svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
domainID, err := svc.resolveDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return errFailedResolveDomain
|
||||
}
|
||||
channelID, err := svc.resolveChannel(ctx, channel, domainID)
|
||||
if err != nil {
|
||||
return errFailedResolveChannel
|
||||
}
|
||||
|
||||
clientID, err := svc.authorize(ctx, clientKey, domainID, channelID, connections.Subscribe)
|
||||
if err != nil {
|
||||
return svcerr.ErrAuthorization
|
||||
@@ -97,15 +79,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domain, channel, subtopic string) error {
|
||||
domainID, err := svc.resolveDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return errors.Wrap(errFailedResolveDomain, err)
|
||||
}
|
||||
channelID, err := svc.resolveChannel(ctx, channel, domainID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errFailedResolveChannel, err)
|
||||
}
|
||||
func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID, channelID, subtopic string) error {
|
||||
topic := messaging.EncodeTopic(domainID, channelID, subtopic)
|
||||
|
||||
if err := svc.pubsub.Unsubscribe(ctx, sessionID, topic); err != nil {
|
||||
@@ -148,32 +122,3 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, c
|
||||
|
||||
return authnRes.GetId(), nil
|
||||
}
|
||||
|
||||
func (svc *adapterService) resolveDomain(ctx context.Context, domain string) (string, error) {
|
||||
if api.ValidateUUID(domain) == nil {
|
||||
return domain, nil
|
||||
}
|
||||
d, err := svc.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: domain,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return d.Entity.Id, nil
|
||||
}
|
||||
|
||||
func (svc *adapterService) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
|
||||
if api.ValidateUUID(channel) == nil {
|
||||
return channel, nil
|
||||
}
|
||||
c, err := svc.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: channel,
|
||||
DomainId: domainID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return c.Entity.Id, nil
|
||||
}
|
||||
|
||||
+3
-10
@@ -12,10 +12,8 @@ import (
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
|
||||
chmocks "github.com/absmach/supermq/channels/mocks"
|
||||
climocks "github.com/absmach/supermq/clients/mocks"
|
||||
dmocks "github.com/absmach/supermq/domains/mocks"
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
@@ -52,17 +50,16 @@ var (
|
||||
sessionID = "sessionID"
|
||||
)
|
||||
|
||||
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient, *dmocks.DomainsServiceClient) {
|
||||
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient) {
|
||||
pubsub := new(mocks.PubSub)
|
||||
clients := new(climocks.ClientsServiceClient)
|
||||
channels := new(chmocks.ChannelsServiceClient)
|
||||
domains := new(dmocks.DomainsServiceClient)
|
||||
|
||||
return ws.New(clients, channels, domains, pubsub), pubsub, clients, channels, domains
|
||||
return ws.New(clients, channels, pubsub), pubsub, clients, channels
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
svc, pubsub, clients, channels, domains := newService()
|
||||
svc, pubsub, clients, channels := newService()
|
||||
|
||||
c := ws.NewClient(slog.Default(), nil, sessionID)
|
||||
|
||||
@@ -197,7 +194,6 @@ func TestSubscribe(t *testing.T) {
|
||||
if strings.HasPrefix(tc.clientKey, "Client") {
|
||||
authReq.ClientSecret = strings.TrimPrefix(tc.clientKey, "Client ")
|
||||
}
|
||||
domainsCall := domains.On("RetrieveByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil)
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, authReq).Return(tc.authNRes, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
ClientType: policies.ClientType,
|
||||
@@ -206,14 +202,11 @@ func TestSubscribe(t *testing.T) {
|
||||
ChannelId: tc.chanID,
|
||||
DomainId: tc.domainID,
|
||||
}).Return(tc.authZRes, tc.authZErr)
|
||||
channelsCall1 := channels.On("RetrieveByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.chanID}}, nil)
|
||||
repocall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
|
||||
err := svc.Subscribe(context.Background(), sessionID, tc.clientKey, tc.domainID, tc.chanID, tc.subtopic, c)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
repocall.Unset()
|
||||
clientsCall.Unset()
|
||||
channelsCall.Unset()
|
||||
domainsCall.Unset()
|
||||
channelsCall1.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
chmocks "github.com/absmach/supermq/channels/mocks"
|
||||
climocks "github.com/absmach/supermq/clients/mocks"
|
||||
dmocks "github.com/absmach/supermq/domains/mocks"
|
||||
@@ -26,6 +25,7 @@ import (
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authnMocks "github.com/absmach/supermq/pkg/authn/mocks"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/pkg/messaging/mocks"
|
||||
"github.com/absmach/supermq/ws"
|
||||
"github.com/absmach/supermq/ws/api"
|
||||
@@ -47,13 +47,13 @@ var (
|
||||
id = testsutil.GenerateUUID(&testing.T{})
|
||||
)
|
||||
|
||||
func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (ws.Service, *mocks.PubSub) {
|
||||
func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) (ws.Service, *mocks.PubSub) {
|
||||
pubsub := new(mocks.PubSub)
|
||||
return ws.New(clients, channels, domains, pubsub), pubsub
|
||||
return ws.New(clients, channels, pubsub), pubsub
|
||||
}
|
||||
|
||||
func newHTTPServer(svc ws.Service) *httptest.Server {
|
||||
mux := api.MakeHandler(context.Background(), svc, smqlog.NewMock(), instanceID)
|
||||
func newHTTPServer(svc ws.Service, resolver messaging.TopicResolver) *httptest.Server {
|
||||
mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), instanceID)
|
||||
return httptest.NewServer(mux)
|
||||
}
|
||||
|
||||
@@ -116,10 +116,11 @@ func TestHandshake(t *testing.T) {
|
||||
channels := new(chmocks.ChannelsServiceClient)
|
||||
authn := new(authnMocks.Authentication)
|
||||
domains := new(dmocks.DomainsServiceClient)
|
||||
svc, pubsub := newService(clients, channels, domains)
|
||||
target := newHTTPServer(svc)
|
||||
resolver := messaging.NewTopicResolver(channels, domains)
|
||||
svc, pubsub := newService(clients, channels)
|
||||
target := newHTTPServer(svc, resolver)
|
||||
defer target.Close()
|
||||
handler := ws.NewHandler(pubsub, smqlog.NewMock(), authn, clients, channels, domains)
|
||||
handler := ws.NewHandler(pubsub, smqlog.NewMock(), authn, clients, channels, resolver)
|
||||
ts, err := newProxyHTPPServer(handler, target)
|
||||
require.Nil(t, err)
|
||||
defer ts.Close()
|
||||
|
||||
+13
-8
@@ -27,9 +27,9 @@ func generateSessionID() (string, error) {
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func handshake(ctx context.Context, svc ws.Service, logger *slog.Logger) http.HandlerFunc {
|
||||
func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicResolver, logger *slog.Logger) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
req, err := decodeRequest(r, logger)
|
||||
req, err := decodeRequest(r, resolver, logger)
|
||||
if err != nil {
|
||||
encodeError(w, err)
|
||||
return
|
||||
@@ -51,21 +51,21 @@ func handshake(ctx context.Context, svc ws.Service, logger *slog.Logger) http.Ha
|
||||
client := ws.NewClient(logger, conn, sessionID)
|
||||
|
||||
client.SetCloseHandler(func(code int, text string) error {
|
||||
return svc.Unsubscribe(ctx, sessionID, req.domain, req.channel, req.subtopic)
|
||||
return svc.Unsubscribe(ctx, sessionID, req.domainID, req.channelID, req.subtopic)
|
||||
})
|
||||
|
||||
go client.Start(ctx)
|
||||
|
||||
if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domain, req.channel, req.subtopic, client); err != nil {
|
||||
if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.channelID, req.subtopic, client); err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.channel))
|
||||
logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.channelID))
|
||||
}
|
||||
}
|
||||
|
||||
func decodeRequest(r *http.Request, logger *slog.Logger) (connReq, error) {
|
||||
func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *slog.Logger) (connReq, error) {
|
||||
authKey := r.Header.Get("Authorization")
|
||||
if authKey == "" {
|
||||
authKeys := r.URL.Query()["authorization"]
|
||||
@@ -79,10 +79,15 @@ func decodeRequest(r *http.Request, logger *slog.Logger) (connReq, error) {
|
||||
domain := chi.URLParam(r, "domain")
|
||||
channel := chi.URLParam(r, "channel")
|
||||
|
||||
domainID, channelID, err := resolver.Resolve(r.Context(), domain, channel)
|
||||
if err != nil {
|
||||
return connReq{}, err
|
||||
}
|
||||
|
||||
req := connReq{
|
||||
clientKey: authKey,
|
||||
channel: channel,
|
||||
domain: domain,
|
||||
channelID: channelID,
|
||||
domainID: domainID,
|
||||
}
|
||||
|
||||
subTopic := chi.URLParam(r, "*")
|
||||
|
||||
+2
-2
@@ -5,7 +5,7 @@ package api
|
||||
|
||||
type connReq struct {
|
||||
clientKey string
|
||||
channel string
|
||||
domain string
|
||||
channelID string
|
||||
domainID string
|
||||
subtopic string
|
||||
}
|
||||
|
||||
+4
-3
@@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/ws"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -36,12 +37,12 @@ var (
|
||||
)
|
||||
|
||||
// MakeHandler returns http handler with handshake endpoint.
|
||||
func MakeHandler(ctx context.Context, svc ws.Service, l *slog.Logger, instanceID string) http.Handler {
|
||||
func MakeHandler(ctx context.Context, svc ws.Service, resolver messaging.TopicResolver, l *slog.Logger, instanceID string) http.Handler {
|
||||
logger = l
|
||||
|
||||
mux := chi.NewRouter()
|
||||
mux.Get("/m/{domain}/c/{channel}", handshake(ctx, svc, l))
|
||||
mux.Get("/m/{domain}/c/{channel}/*", handshake(ctx, svc, l))
|
||||
mux.Get("/m/{domain}/c/{channel}", handshake(ctx, svc, resolver, l))
|
||||
mux.Get("/m/{domain}/c/{channel}/*", handshake(ctx, svc, resolver, l))
|
||||
|
||||
mux.Get("/health", supermq.Health(service, instanceID))
|
||||
mux.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
+12
-57
@@ -15,9 +15,6 @@ import (
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
|
||||
grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1"
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
@@ -53,20 +50,20 @@ type handler struct {
|
||||
pubsub messaging.PubSub
|
||||
clients grpcClientsV1.ClientsServiceClient
|
||||
channels grpcChannelsV1.ChannelsServiceClient
|
||||
domains grpcDomainsV1.DomainsServiceClient
|
||||
authn smqauthn.Authentication
|
||||
logger *slog.Logger
|
||||
resolver messaging.TopicResolver
|
||||
}
|
||||
|
||||
// NewHandler creates new Handler entity.
|
||||
func NewHandler(pubsub messaging.PubSub, logger *slog.Logger, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) session.Handler {
|
||||
func NewHandler(pubsub messaging.PubSub, logger *slog.Logger, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, resolver messaging.TopicResolver) session.Handler {
|
||||
return &handler{
|
||||
logger: logger,
|
||||
pubsub: pubsub,
|
||||
authn: authn,
|
||||
clients: clients,
|
||||
channels: channels,
|
||||
domains: domains,
|
||||
resolver: resolver,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,18 +94,14 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
|
||||
|
||||
domain, channel, _, err := messaging.ParsePublishTopic(*topic)
|
||||
if err != nil {
|
||||
return err
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
|
||||
}
|
||||
domainID, err := h.resolveDomain(ctx, domain)
|
||||
domainID, channelID, err := h.resolver.Resolve(ctx, domain, channel)
|
||||
if err != nil {
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveDomain, err))
|
||||
}
|
||||
chanID, err := h.resolveChannel(ctx, channel, domainID)
|
||||
if err != nil {
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveChannel, err))
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
|
||||
}
|
||||
|
||||
clientID, clientType, err := h.authAccess(ctx, token, domainID, chanID, connections.Publish)
|
||||
clientID, clientType, err := h.authAccess(ctx, token, domainID, channelID, connections.Publish)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -136,19 +129,14 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
domainID, err := h.resolveDomain(ctx, domain)
|
||||
domainID, chanID, err := h.resolver.Resolve(ctx, domain, channel)
|
||||
if err != nil {
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveDomain, err))
|
||||
}
|
||||
chanID, err := h.resolveChannel(ctx, channel, domainID)
|
||||
if err != nil {
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveChannel, err))
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
|
||||
}
|
||||
if _, _, err := h.authAccess(ctx, string(s.Password), domainID, chanID, connections.Subscribe); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -173,19 +161,15 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
if err != nil {
|
||||
return errors.Wrap(errFailedPublish, err)
|
||||
}
|
||||
domainID, err := h.resolveDomain(ctx, domain)
|
||||
domainID, channelID, err := h.resolver.Resolve(ctx, domain, channel)
|
||||
if err != nil {
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveDomain, err))
|
||||
}
|
||||
chanID, err := h.resolveChannel(ctx, channel, domainID)
|
||||
if err != nil {
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedResolveChannel, err))
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
|
||||
}
|
||||
|
||||
msg := messaging.Message{
|
||||
Protocol: protocol,
|
||||
Domain: domainID,
|
||||
Channel: chanID,
|
||||
Channel: channelID,
|
||||
Subtopic: subtopic,
|
||||
Payload: *payload,
|
||||
Publisher: s.Username,
|
||||
@@ -257,35 +241,6 @@ func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string
|
||||
return clientID, clientType, nil
|
||||
}
|
||||
|
||||
func (h *handler) resolveDomain(ctx context.Context, domain string) (string, error) {
|
||||
if api.ValidateUUID(domain) == nil {
|
||||
return domain, nil
|
||||
}
|
||||
d, err := h.domains.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: domain,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return d.Entity.Id, nil
|
||||
}
|
||||
|
||||
func (h *handler) resolveChannel(ctx context.Context, channel, domainID string) (string, error) {
|
||||
if api.ValidateUUID(channel) == nil {
|
||||
return channel, nil
|
||||
}
|
||||
c, err := h.channels.RetrieveByRoute(ctx, &grpcCommonV1.RetrieveByRouteReq{
|
||||
Route: channel,
|
||||
DomainId: domainID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return c.Entity.Id, nil
|
||||
}
|
||||
|
||||
// extractClientSecret returns value of the client secret. If there is no client key - an empty value is returned.
|
||||
func extractClientSecret(token string) string {
|
||||
if !strings.HasPrefix(token, apiutil.ClientPrefix) {
|
||||
|
||||
Reference in New Issue
Block a user