SMQ-2801 - Add health check endpoint for MQTT adapter (#3024)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2025-09-12 12:37:29 +03:00
committed by GitHub
parent 1d935f0519
commit ad191c4609
8 changed files with 255 additions and 87 deletions
+2 -2
View File
@@ -154,9 +154,9 @@ func (h *CoAPHandler) decodeMessage(msg *mux.Message) (*messaging.Message, error
var domainID, channelID, subTopic string
switch msg.Code() {
case codes.GET:
domainID, channelID, subTopic, err = h.parser.ParseSubscribeTopic(msg.Context(), path, true)
domainID, channelID, subTopic, _, err = h.parser.ParseSubscribeTopic(msg.Context(), path, true)
case codes.POST:
domainID, channelID, subTopic, err = h.parser.ParsePublishTopic(msg.Context(), path, true)
domainID, channelID, subTopic, _, err = h.parser.ParsePublishTopic(msg.Context(), path, true)
}
if err != nil {
return &messaging.Message{}, err
+1 -1
View File
@@ -113,7 +113,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return errors.Wrap(errFailedPublish, errClientNotInitialized)
}
domainID, channelID, subtopic, err := h.parser.ParsePublishTopic(ctx, *topic, true)
domainID, channelID, subtopic, _, err := h.parser.ParsePublishTopic(ctx, *topic, true)
if err != nil {
return errors.Wrap(errMalformedTopic, err)
}
+1 -1
View File
@@ -90,7 +90,7 @@ func (es *eventStore) Subscribe(ctx context.Context, topics *[]string) error {
}
for _, topic := range *topics {
domainID, channelID, subTopic, err := messaging.ParseSubscribeTopic(topic)
domainID, channelID, subTopic, _, err := messaging.ParseSubscribeTopic(topic)
if err != nil {
return err
}
+30 -23
View File
@@ -113,12 +113,12 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
return ErrClientNotInitialized
}
domainID, chanID, _, err := h.parser.ParsePublishTopic(ctx, *topic, false)
domainID, chanID, _, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, false)
if err != nil {
return err
}
return h.authAccess(ctx, string(s.Username), domainID, chanID, connections.Publish)
return h.authAccess(ctx, string(s.Username), domainID, chanID, connections.Publish, topicType)
}
// AuthSubscribe is called on device subscribe,
@@ -133,12 +133,12 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
}
for _, topic := range *topics {
domainID, chanID, _, err := h.parser.ParseSubscribeTopic(ctx, topic, false)
domainID, chanID, _, topicType, err := h.parser.ParseSubscribeTopic(ctx, topic, false)
if err != nil {
return err
}
if err := h.authAccess(ctx, string(s.Username), domainID, chanID, connections.Subscribe); err != nil {
if err := h.authAccess(ctx, string(s.Username), domainID, chanID, connections.Subscribe, topicType); err != nil {
return err
}
}
@@ -164,7 +164,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
}
h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic))
domainID, chanID, subTopic, err := h.parser.ParsePublishTopic(ctx, *topic, false)
domainID, chanID, subTopic, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, false)
if err != nil {
return errors.Wrap(ErrFailedPublish, err)
}
@@ -179,8 +179,10 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
Created: time.Now().UnixNano(),
}
if err := h.publisher.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil {
return errors.Wrap(ErrFailedPublishToMsgBroker, err)
if topicType == messaging.MessageType {
if err := h.publisher.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil {
return errors.Wrap(ErrFailedPublishToMsgBroker, err)
}
}
return nil
@@ -219,21 +221,26 @@ func (h *handler) Disconnect(ctx context.Context) error {
return nil
}
func (h *handler) authAccess(ctx context.Context, clientID, domainID, chanID string, msgType connections.ConnType) error {
ar := &grpcChannelsV1.AuthzReq{
Type: uint32(msgType),
ClientId: clientID,
ClientType: policies.ClientType,
ChannelId: chanID,
DomainId: domainID,
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
return err
}
if !res.GetAuthorized() {
return svcerr.ErrAuthorization
}
func (h *handler) authAccess(ctx context.Context, clientID, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) error {
switch topicType {
case messaging.HealthType:
return nil
default:
ar := &grpcChannelsV1.AuthzReq{
Type: uint32(msgType),
ClientId: clientID,
ClientType: policies.ClientType,
ChannelId: chanID,
DomainId: domainID,
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
return err
}
if !res.GetAuthorized() {
return svcerr.ErrAuthorization
}
return nil
return nil
}
}
+46
View File
@@ -46,6 +46,9 @@ var (
domainID = testsutil.GenerateUUID(&testing.T{})
topicMsg = "/m/%s/c/%s"
topic = fmt.Sprintf(topicMsg, domainID, chanID)
hcTopicFmt = "/hc/%s"
hcTopic = fmt.Sprintf(hcTopicFmt, domainID)
invalidHCTopic = "/hc"
invalidTopic = invalidValue
payload = []byte("[{'n':'test-name', 'v': 1.2}]")
topics = []string{topic}
@@ -210,6 +213,21 @@ func TestAuthPublish(t *testing.T) {
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: svcerr.ErrAuthorization,
},
{
desc: "publish to health check topic",
session: &sessionClient,
err: nil,
topic: &hcTopic,
payload: payload,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
},
{
desc: "publich with invalid health check topic",
session: &sessionClient,
err: messaging.ErrMalformedTopic,
topic: &invalidHCTopic,
payload: payload,
},
}
for _, tc := range cases {
@@ -286,6 +304,20 @@ func TestAuthSubscribe(t *testing.T) {
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
channelID: chanID,
},
{
desc: "subscribe successfully with health check topic",
session: &sessionClientSub,
err: nil,
topic: &[]string{hcTopic},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
channelID: "",
},
{
desc: "subscribe with invalid health check topic",
session: &sessionClientSub,
err: messaging.ErrMalformedTopic,
topic: &[]string{invalidHCTopic},
},
}
for _, tc := range cases {
@@ -408,6 +440,20 @@ func TestPublish(t *testing.T) {
payload: payload,
logMsg: "",
},
{
desc: "publish with health check topic",
session: &sessionClient,
topic: hcTopic,
payload: payload,
logMsg: "",
},
{
desc: "publish with invalid health check topic",
session: &sessionClient,
topic: invalidHCTopic,
payload: payload,
err: errors.Wrap(mqtt.ErrFailedPublish, messaging.ErrMalformedTopic),
},
}
for _, tc := range cases {
+83 -46
View File
@@ -43,6 +43,14 @@ var (
ErrCreateCache = errors.New("failed to create cache")
)
type TopicType uint8
const (
InvalidType TopicType = iota
MessageType
HealthType
)
type CacheConfig struct {
NumCounters int64 `env:"NUM_COUNTERS" envDefault:"200000"` // number of keys to track frequency of.
MaxCost int64 `env:"MAX_COST" envDefault:"1048576"` // maximum cost of cache.
@@ -60,8 +68,8 @@ type parsedTopic struct {
// It uses a cache to store parsed topics for quick retrieval.
// It also resolves domain and channel IDs if requested.
type TopicParser interface {
ParsePublishTopic(ctx context.Context, topic string, resolve bool) (domainID, channelID, subtopic string, err error)
ParseSubscribeTopic(ctx context.Context, topic string, resolve bool) (domainID, channelID, subtopic string, err error)
ParsePublishTopic(ctx context.Context, topic string, resolve bool) (domainID, channelID, subtopic string, topicType TopicType, err error)
ParseSubscribeTopic(ctx context.Context, topic string, resolve bool) (domainID, channelID, subtopic string, topicType TopicType, err error)
}
type parser struct {
@@ -86,43 +94,43 @@ func NewTopicParser(cfg CacheConfig, channels grpcChannelsV1.ChannelsServiceClie
}, nil
}
func (p *parser) ParsePublishTopic(ctx context.Context, topic string, resolve bool) (string, string, string, error) {
func (p *parser) ParsePublishTopic(ctx context.Context, topic string, resolve bool) (string, string, string, TopicType, error) {
val, ok := p.cache.Get(topic)
if ok {
return val.domainID, val.channelID, val.subtopic, val.err
return val.domainID, val.channelID, val.subtopic, MessageType, val.err
}
domainID, channelID, subtopic, err := ParsePublishTopic(topic)
domainID, channelID, subtopic, topicType, err := ParsePublishTopic(topic)
if err != nil {
p.saveToCache(topic, "", "", "", err)
return "", "", "", err
return "", "", "", InvalidType, err
}
var isRoute bool
if resolve {
domainID, channelID, isRoute, err = p.resolver.Resolve(ctx, domainID, channelID)
if err != nil {
return "", "", "", err
return "", "", "", InvalidType, err
}
}
if !isRoute {
if !isRoute && topicType == MessageType {
p.saveToCache(topic, domainID, channelID, subtopic, nil)
}
return domainID, channelID, subtopic, nil
return domainID, channelID, subtopic, topicType, nil
}
func (p *parser) ParseSubscribeTopic(ctx context.Context, topic string, resolve bool) (string, string, string, error) {
domainID, channelID, subtopic, err := ParseSubscribeTopic(topic)
func (p *parser) ParseSubscribeTopic(ctx context.Context, topic string, resolve bool) (string, string, string, TopicType, error) {
domainID, channelID, subtopic, topicType, err := ParseSubscribeTopic(topic)
if err != nil {
return "", "", "", err
return "", "", "", InvalidType, err
}
if resolve {
domainID, channelID, _, err = p.resolver.Resolve(ctx, domainID, channelID)
if err != nil {
return "", "", "", err
return "", "", "", InvalidType, err
}
}
return domainID, channelID, subtopic, nil
return domainID, channelID, subtopic, topicType, nil
}
func (p *parser) saveToCache(topic string, domainID, channelID, subtopic string, err error) {
@@ -165,25 +173,28 @@ func NewTopicResolver(channelsClient grpcChannelsV1.ChannelsServiceClient, domai
}
func (r *resolver) Resolve(ctx context.Context, domain, channel string) (string, string, bool, error) {
if domain == "" || channel == "" {
if domain == "" {
return "", "", false, ErrEmptyRouteID
}
domainID, isdomainRoute, err := r.resolveDomain(ctx, domain)
domainID, isDomainRoute, err := r.resolveDomain(ctx, domain)
if err != nil {
return "", "", false, errors.Wrap(ErrFailedResolveDomain, err)
}
if channel == "" {
return domainID, "", isDomainRoute, nil
}
channelID, isChannelRoute, err := r.resolveChannel(ctx, channel, domainID)
if err != nil {
return "", "", false, errors.Wrap(ErrFailedResolveChannel, err)
}
isRoute := isdomainRoute || isChannelRoute
isRoute := isDomainRoute || isChannelRoute
return domainID, channelID, isRoute, nil
}
func (r *resolver) ResolveTopic(ctx context.Context, topic string) (string, error) {
domain, channel, subtopic, err := ParseTopic(topic)
domain, channel, subtopic, topicType, err := ParseTopic(topic)
if err != nil {
return "", errors.Wrap(ErrMalformedTopic, err)
}
@@ -192,7 +203,7 @@ func (r *resolver) ResolveTopic(ctx context.Context, topic string) (string, erro
if err != nil {
return "", err
}
rtopic := EncodeAdapterTopic(domainID, channelID, subtopic)
rtopic := encodeAdapterTopic(domainID, channelID, subtopic, topicType)
return rtopic, nil
}
@@ -235,17 +246,17 @@ func validateUUID(extID string) (err error) {
return nil
}
func ParsePublishTopic(topic string) (domainID, chanID, subtopic string, err error) {
domainID, chanID, subtopic, err = ParseTopic(topic)
func ParsePublishTopic(topic string) (domainID, chanID, subtopic string, topicType TopicType, err error) {
domainID, chanID, subtopic, topicType, err = ParseTopic(topic)
if err != nil {
return "", "", "", err
return "", "", "", InvalidType, err
}
subtopic, err = ParsePublishSubtopic(subtopic)
if err != nil {
return "", "", "", errors.Wrap(ErrMalformedTopic, err)
return "", "", "", InvalidType, errors.Wrap(ErrMalformedTopic, err)
}
return domainID, chanID, subtopic, nil
return domainID, chanID, subtopic, topicType, nil
}
func ParsePublishSubtopic(subtopic string) (parseSubTopic string, err error) {
@@ -269,17 +280,17 @@ func ParsePublishSubtopic(subtopic string) (parseSubTopic string, err error) {
return subtopic, nil
}
func ParseSubscribeTopic(topic string) (domainID string, chanID string, subtopic string, err error) {
domainID, chanID, subtopic, err = ParseTopic(topic)
func ParseSubscribeTopic(topic string) (domainID string, chanID string, subtopic string, topicType TopicType, err error) {
domainID, chanID, subtopic, topicType, err = ParseTopic(topic)
if err != nil {
return "", "", "", err
return "", "", "", InvalidType, err
}
subtopic, err = ParseSubscribeSubtopic(subtopic)
if err != nil {
return "", "", "", errors.Wrap(ErrMalformedTopic, err)
return "", "", "", InvalidType, errors.Wrap(ErrMalformedTopic, err)
}
return domainID, chanID, subtopic, nil
return domainID, chanID, subtopic, topicType, nil
}
func ParseSubscribeSubtopic(subtopic string) (parseSubTopic string, err error) {
@@ -347,33 +358,61 @@ func EncodeMessageMQTTTopic(m *Message) string {
return topic
}
func EncodeAdapterTopic(domain, channel, subtopic string) string {
topic := fmt.Sprintf("%s/%s/%s/%s", string(MsgTopicPrefix), domain, string(ChannelTopicPrefix), channel)
if subtopic != "" {
topic = topic + "/" + subtopic
func encodeAdapterTopic(domain, channel, subtopic string, topicType TopicType) string {
switch topicType {
case HealthType:
return fmt.Sprintf("%s/%s", string(HealthTopicPrefix), domain)
default:
topic := fmt.Sprintf("%s/%s/%s/%s", string(MsgTopicPrefix), domain, string(ChannelTopicPrefix), channel)
if subtopic != "" {
topic = topic + "/" + subtopic
}
return topic
}
return topic
}
// ParseTopic parses a messaging topic string and returns the domain ID, channel ID, and subtopic.
// Supported formats (leading '/' optional):
//
// m/<domain_id>/c/<channel_id>[/<subtopic>]
// hc/<domain_id>
//
// This is an optimized version with no regex and minimal allocations.
func ParseTopic(topic string) (domainID, chanID, subtopic string, err error) {
// location of string "m"
func ParseTopic(topic string) (domainID, chanID, subtopic string, topicType TopicType, err error) {
start := 0
// Handle both formats: "/m/domain/c/channel/subtopic" and "m/domain/c/channel/subtopic".
// If topic start with m/ then start is 0 , If topic start with /m/ then start is 1.
n := len(topic)
if n > 0 && topic[0] == '/' {
start = 1
}
if n <= start {
return "", "", "", InvalidType, ErrMalformedTopic
}
// Healthcheck: "hc/<domain_id>"
// Check first because it's shortest and avoids extra work.
if n > start+3 && topic[start:start+2] == HealthTopicPrefix {
if n == start+3 {
// "hc/" with no domain
return "", "", "", InvalidType, ErrMalformedTopic
}
// Domain is the remainder; ensure no extra '/'
domainID = topic[start+3:]
for i := start + 3; i < n; i++ {
if topic[i] == '/' {
return "", "", "", InvalidType, ErrMalformedTopic
}
}
return domainID, "", "", HealthType, nil
}
// Messaging: "m/<domain_id>/c/<channel_id>[/<subtopic>]"
// length check - minimum: "m/<domain_id>/c/" = 5 characters if ignore <domain_id> and in this case start will be 0
// length check - minimum: "/m/<domain_id>/c/" = 6 characters if ignore <domain_id> and in this case start will be 1
if n < start+5 {
return "", "", "", ErrMalformedTopic
return "", "", "", InvalidType, ErrMalformedTopic
}
if topic[start] != MsgTopicPrefix || topic[start+1] != '/' {
return "", "", "", ErrMalformedTopic
return "", "", "", InvalidType, ErrMalformedTopic
}
pos := start + 2
@@ -386,7 +425,7 @@ func ParseTopic(topic string) (domainID, chanID, subtopic string, err error) {
}
}
if cPos == -1 || cPos == 0 {
return "", "", "", ErrMalformedTopic
return "", "", "", InvalidType, ErrMalformedTopic
}
domainID = topic[pos : pos+cPos]
// skip "/c/"
@@ -394,7 +433,7 @@ func ParseTopic(topic string) (domainID, chanID, subtopic string, err error) {
// Ensure channel exists
if pos >= n {
return "", "", "", ErrMalformedTopic
return "", "", "", InvalidType, ErrMalformedTopic
}
// Find '/' after channelID
@@ -407,17 +446,15 @@ func ParseTopic(topic string) (domainID, chanID, subtopic string, err error) {
}
if nextSlash == -1 {
// No subtopic
chanID = topic[pos:]
} else {
chanID = topic[pos : pos+nextSlash]
subtopic = topic[pos+nextSlash+1:]
}
// Validate channelID
if len(chanID) == 0 {
return "", "", "", ErrMalformedTopic
return "", "", "", InvalidType, ErrMalformedTopic
}
return domainID, chanID, subtopic, nil
return domainID, chanID, subtopic, MessageType, nil
}
+89 -11
View File
@@ -26,6 +26,7 @@ var (
channelID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
topicFmt = "m/%s/c/%s"
healthTopicFmt = "hc/%s"
subtopic = "subtopic"
topicSubtopicFmt = "m/%s/c/%s/%s"
cachedTopic = fmt.Sprintf(topicSubtopicFmt, domainID, channelID, subtopic)
@@ -56,6 +57,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID string
channelID string
subtopic string
topicType messaging.TopicType
err error
}{
{
@@ -64,6 +66,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "devices.temp",
topicType: messaging.MessageType,
err: nil,
},
{
@@ -72,6 +75,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "devices.temp.data",
topicType: messaging.MessageType,
},
{
desc: "valid topic with subtopic /m/domain/c/channel/extra/extra2",
@@ -79,6 +83,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain",
channelID: "channel",
subtopic: "extra.extra2",
topicType: messaging.MessageType,
},
{
desc: "valid topic without subtopic /m/domain123/c/channel456",
@@ -86,6 +91,7 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
topicType: messaging.MessageType,
},
{
desc: "valid topic with trailing slash /m/domain123/c/channel456/devices/temp/",
@@ -93,6 +99,25 @@ var ParsePublisherTopicTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "devices.temp",
topicType: messaging.MessageType,
},
{
desc: "valid health check topic",
topic: fmt.Sprintf(healthTopicFmt, domainID),
domainID: domainID,
channelID: "",
subtopic: "",
topicType: messaging.HealthType,
err: nil,
},
{
desc: "invalid health check topic with empty domain",
topic: "hc/",
domainID: "",
channelID: "",
subtopic: "",
topicType: messaging.InvalidType,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid topic format (missing parts) /m/domain123/c/",
@@ -192,12 +217,13 @@ var ParsePublisherTopicTestCases = []struct {
func TestParsePublishTopic(t *testing.T) {
for _, tc := range ParsePublisherTopicTestCases {
t.Run(tc.desc, func(t *testing.T) {
domainID, channelID, subtopic, err := messaging.ParsePublishTopic(tc.topic)
domainID, channelID, subtopic, topicType, err := messaging.ParsePublishTopic(tc.topic)
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
if err == nil {
assert.Equal(t, tc.domainID, domainID)
assert.Equal(t, tc.channelID, channelID)
assert.Equal(t, tc.subtopic, subtopic)
assert.Equal(t, tc.topicType, topicType)
}
})
}
@@ -207,7 +233,7 @@ func BenchmarkParsePublisherTopic(b *testing.B) {
for _, tc := range ParsePublisherTopicTestCases {
b.Run(tc.desc, func(b *testing.B) {
for b.Loop() {
_, _, _, _ = messaging.ParsePublishTopic(tc.topic)
_, _, _, _, _ = messaging.ParsePublishTopic(tc.topic)
}
})
}
@@ -219,6 +245,7 @@ var ParseSubscribeTestCases = []struct {
domainID string
channelID string
subtopic string
topicType messaging.TopicType
err error
}{
{
@@ -227,6 +254,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "devices.temp",
topicType: messaging.MessageType,
},
{
desc: "topic with wildcards + and # /m/domain123/c/channel456/devices/+/temp/#",
@@ -234,6 +262,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "devices.*.temp.>",
topicType: messaging.MessageType,
},
{
desc: "valid topic without subtopic /m/domain123/c/channel456",
@@ -241,6 +270,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "",
topicType: messaging.MessageType,
},
{
desc: "valid topic with trailing slash /m/domain123/c/channel456/devices/temp/",
@@ -248,6 +278,25 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain123",
channelID: "channel456",
subtopic: "devices.temp",
topicType: messaging.MessageType,
},
{
desc: "valid health check topic",
topic: fmt.Sprintf(healthTopicFmt, domainID),
domainID: domainID,
channelID: "",
subtopic: "",
topicType: messaging.HealthType,
err: nil,
},
{
desc: "invalid health check topic with empty domain",
topic: "hc/",
domainID: "",
channelID: "",
subtopic: "",
topicType: messaging.InvalidType,
err: messaging.ErrMalformedTopic,
},
{
desc: "invalid topic format (missing channel) /m/domain123/c/",
@@ -279,6 +328,7 @@ var ParseSubscribeTestCases = []struct {
domainID: "domain*123",
channelID: "channel456",
subtopic: "devices.*.temp.>",
topicType: messaging.MessageType,
},
{
desc: "invalid subtopic /m/domain123/c/channel456/sub/a*b/topic",
@@ -346,12 +396,13 @@ var ParseSubscribeTestCases = []struct {
func TestParseSubscribeTopic(t *testing.T) {
for _, tc := range ParseSubscribeTestCases {
t.Run(tc.desc, func(t *testing.T) {
domainID, channelID, subtopic, err := messaging.ParseSubscribeTopic(tc.topic)
domainID, channelID, subtopic, topicType, err := messaging.ParseSubscribeTopic(tc.topic)
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
if err == nil {
assert.Equal(t, tc.domainID, domainID)
assert.Equal(t, tc.channelID, channelID)
assert.Equal(t, tc.subtopic, subtopic)
assert.Equal(t, tc.topicType, topicType)
}
})
}
@@ -361,7 +412,7 @@ func BenchmarkParseSubscribeTopic(b *testing.B) {
for _, tc := range ParseSubscribeTestCases {
b.Run(tc.desc, func(b *testing.B) {
for b.Loop() {
_, _, _, _ = messaging.ParseSubscribeTopic(tc.topic)
_, _, _, _, _ = messaging.ParseSubscribeTopic(tc.topic)
}
})
}
@@ -579,7 +630,7 @@ func TestResolve(t *testing.T) {
channel: "",
domainID: domainID,
channelID: "",
err: messaging.ErrEmptyRouteID,
err: nil,
},
}
for _, tc := range cases {
@@ -724,17 +775,19 @@ func TestParserPublishTopic(t *testing.T) {
cachedInvalidTopic := "m/invalid-domain/c"
dom, ch, st, err := parser.ParsePublishTopic(context.Background(), cachedTopic, false)
dom, ch, st, tt, err := parser.ParsePublishTopic(context.Background(), cachedTopic, false)
assert.Nil(t, err, fmt.Sprintf("unexpected error while publishing topic: %v", err))
assert.Equal(t, domainID, dom, "expected domainID %s, got %s", domainID, dom)
assert.Equal(t, channelID, ch, "expected channelID %s, got %s", channelID, ch)
assert.Equal(t, subtopic, st, "expected subtopic %s, got %s", subtopic, st)
assert.Equal(t, messaging.MessageType, tt, "expected topic type %v, got %v", messaging.MessageType, tt)
dom, ch, st, err = parser.ParsePublishTopic(context.Background(), cachedInvalidTopic, false)
dom, ch, st, tt, err = parser.ParsePublishTopic(context.Background(), cachedInvalidTopic, false)
assert.NotNil(t, err, "expected error for invalid cached topic")
assert.Equal(t, "", dom, "expected empty domainID for invalid topic")
assert.Equal(t, "", ch, "expected empty channelID for invalid topic")
assert.Equal(t, "", st, "expected empty subtopic for invalid topic")
assert.Equal(t, messaging.InvalidType, tt, "expected unknown topic type for invalid topic")
time.Sleep(10 * time.Millisecond) // Ensure cache is populated
cases := []struct {
@@ -745,6 +798,8 @@ func TestParserPublishTopic(t *testing.T) {
channel string
domainID string
channelID string
subtopic string
topicType messaging.TopicType
domainsErr error
channelsErr error
err error
@@ -757,6 +812,8 @@ func TestParserPublishTopic(t *testing.T) {
channel: uchannelID,
domainID: udomainID,
channelID: uchannelID,
subtopic: subtopic,
topicType: messaging.MessageType,
err: nil,
},
{
@@ -766,6 +823,8 @@ func TestParserPublishTopic(t *testing.T) {
channel: channelID,
domainID: domainID,
channelID: channelID,
subtopic: subtopic,
topicType: messaging.MessageType,
err: nil,
},
{
@@ -794,6 +853,8 @@ func TestParserPublishTopic(t *testing.T) {
channel: validRoute,
domainID: domainID,
channelID: channelID,
subtopic: subtopic,
topicType: messaging.MessageType,
err: nil,
},
{
@@ -807,6 +868,17 @@ func TestParserPublishTopic(t *testing.T) {
domainsErr: svcerr.ErrNotFound,
err: messaging.ErrFailedResolveDomain,
},
{
desc: "valid uncached healthcheck topic",
topic: fmt.Sprintf(healthTopicFmt, domainID),
domain: domainID,
channel: "",
domainID: domainID,
channelID: "",
subtopic: "",
topicType: messaging.HealthType,
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
@@ -820,12 +892,13 @@ func TestParserPublishTopic(t *testing.T) {
Id: tc.channelID,
},
}, tc.channelsErr)
domainID, channelID, subtopic, err := parser.ParsePublishTopic(context.Background(), tc.topic, tc.resolve)
domainID, channelID, subtopic, topicType, err := parser.ParsePublishTopic(context.Background(), tc.topic, tc.resolve)
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
if err == nil {
assert.Equal(t, tc.domainID, domainID, "expected domainID %s, got %s", tc.domainID, domainID)
assert.Equal(t, tc.channelID, channelID, "expected channelID %s, got %s", tc.channelID, channelID)
assert.Equal(t, subtopic, "subtopic", "expected subtopic %s, got %s", "subtopic", subtopic)
assert.Equal(t, tc.subtopic, subtopic, "expected subtopic %s, got %s", tc.subtopic, subtopic)
assert.Equal(t, tc.topicType, topicType, "expected topic type %v, got %v", tc.topicType, topicType)
}
domainsCall.Unset()
channelsCall.Unset()
@@ -842,7 +915,7 @@ func BenchmarkParserPublishTopic(b *testing.B) {
for _, tc := range ParsePublisherTopicTestCases {
b.Run(tc.desc, func(b *testing.B) {
for b.Loop() {
_, _, _, _ = parser.ParsePublishTopic(context.Background(), tc.topic, false)
_, _, _, _, _ = parser.ParsePublishTopic(context.Background(), tc.topic, false)
}
})
}
@@ -861,6 +934,7 @@ func TestParserSubscribeTopic(t *testing.T) {
domainID string
channelID string
subtopic string
topicType messaging.TopicType
domainsErr error
channelsErr error
err error
@@ -873,6 +947,7 @@ func TestParserSubscribeTopic(t *testing.T) {
channel: channelID,
domainID: domainID,
channelID: channelID,
topicType: messaging.MessageType,
err: nil,
},
{
@@ -884,6 +959,7 @@ func TestParserSubscribeTopic(t *testing.T) {
domainID: domainID,
channelID: channelID,
subtopic: subtopic,
topicType: messaging.MessageType,
err: nil,
},
{
@@ -894,6 +970,7 @@ func TestParserSubscribeTopic(t *testing.T) {
channel: validRoute,
domainID: domainID,
channelID: channelID,
topicType: messaging.MessageType,
err: nil,
},
{
@@ -930,12 +1007,13 @@ func TestParserSubscribeTopic(t *testing.T) {
Id: tc.channelID,
},
}, tc.channelsErr)
dom, ch, st, err := parser.ParseSubscribeTopic(context.Background(), tc.topic, tc.resolve)
dom, ch, st, tt, err := parser.ParseSubscribeTopic(context.Background(), tc.topic, tc.resolve)
assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
if err == nil {
assert.Equal(t, tc.domainID, dom, "expected domainID %s, got %s", tc.domainID, dom)
assert.Equal(t, tc.channelID, ch, "expected channelID %s, got %s", tc.channelID, ch)
assert.Equal(t, tc.subtopic, st, "expected subtopic %s, got %s", tc.subtopic, st)
assert.Equal(t, tc.topicType, tt, "expected topic type %v, got %v", tc.topicType, tt)
}
domainsCall.Unset()
channelsCall.Unset()
+3 -3
View File
@@ -84,7 +84,7 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
return errClientNotInitialized
}
domainID, channelID, _, err := h.parser.ParsePublishTopic(ctx, *topic, true)
domainID, channelID, _, _, err := h.parser.ParsePublishTopic(ctx, *topic, true)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
}
@@ -113,7 +113,7 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
}
for _, topic := range *topics {
domainID, channelID, _, err := h.parser.ParseSubscribeTopic(ctx, topic, true)
domainID, channelID, _, _, err := h.parser.ParseSubscribeTopic(ctx, topic, true)
if err != nil {
return err
}
@@ -141,7 +141,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return nil
}
domainID, channelID, subtopic, err := h.parser.ParsePublishTopic(ctx, *topic, true)
domainID, channelID, subtopic, _, err := h.parser.ParsePublishTopic(ctx, *topic, true)
if err != nil {
return errors.Wrap(errFailedPublish, err)
}