SMQ-2706 - Add domain id to message topic (#2765)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2025-04-11 10:11:39 +03:00
committed by GitHub
parent 299cee7771
commit 6151be72ce
36 changed files with 316 additions and 152 deletions
+10 -1
View File
@@ -37,8 +37,12 @@ servers:
- user-password: []
channels:
c/{channelID}/m/{subtopic}:
/m/{domainID}/c/{channelID}/{subtopic}:
parameters:
domainID:
$ref: '#/components/parameters/domainID'
in: path
required: true
channelID:
$ref: '#/components/parameters/channelID'
in: path
@@ -87,6 +91,11 @@ components:
```
parameters:
domainID:
description: Domain ID associated with the channel and client.
schema:
type: string
format: uuid
channelID:
description: Channel ID connected to the Client ID defined in the username.
schema:
+10 -1
View File
@@ -32,8 +32,12 @@ servers:
default: '8186'
channels:
'c/{channelID}/m/{subtopic}':
'm/{domainID}/c/{channelID}/{subtopic}':
parameters:
domainID:
$ref: '#/components/parameters/domainID'
in: path
required: true
channelID:
$ref: '#/components/parameters/channelID'
in: path
@@ -125,6 +129,11 @@ components:
```
parameters:
domainID:
description: Domain ID associated with the channel and client.
schema:
type: string
format: uuid
channelID:
description: Channel ID connected to the Client ID defined in the username.
schema:
+13 -4
View File
@@ -27,7 +27,7 @@ tags:
url: https://docs.supermq.abstractmachines.fr/
paths:
/c/{id}/m:
/m/{domainID}/c/{channelID}:
post:
summary: Sends message to the communication channel
description: |
@@ -36,7 +36,8 @@ paths:
tags:
- messages
parameters:
- $ref: "#/components/parameters/ID"
- $ref: "#/components/parameters/domainID"
- $ref: "#/components/parameters/channelID"
requestBody:
$ref: "#/components/requestBodies/MessageReq"
responses:
@@ -129,8 +130,16 @@ components:
$ref: "#/components/schemas/SenMLRecord"
parameters:
ID:
name: id
domainID:
name: domainID
description: Unique domain identifier.
in: path
schema:
type: string
format: uuid
required: true
channelID:
name: channelID
description: Unique channel identifier.
in: path
schema:
+14 -2
View File
@@ -8,11 +8,15 @@ import (
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/channels"
dom "github.com/absmach/supermq/domains"
pkgDomains "github.com/absmach/supermq/pkg/domains"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
)
var errDisabledDomain = errors.New("domain is disabled or frozen")
type Service interface {
Authorize(ctx context.Context, req channels.AuthzReq) error
UnsetParentGroupFromChannels(ctx context.Context, parentGroupID string) error
@@ -24,15 +28,23 @@ type service struct {
repo channels.Repository
evaluator policies.Evaluator
policy policies.Service
domains pkgDomains.Authorization
}
var _ Service = (*service)(nil)
func New(repo channels.Repository, evaluator policies.Evaluator, policy policies.Service) Service {
return service{repo, evaluator, policy}
func New(repo channels.Repository, evaluator policies.Evaluator, policy policies.Service, domains pkgDomains.Authorization) Service {
return service{repo, evaluator, policy, domains}
}
func (svc service) Authorize(ctx context.Context, req channels.AuthzReq) error {
d, err := svc.domains.RetrieveEntity(ctx, req.DomainID)
if err != nil {
return errors.Wrap(svcerr.ErrAuthorization, err)
}
if d.Status != dom.EnabledStatus {
return errors.Wrap(svcerr.ErrAuthorization, errDisabledDomain)
}
switch req.ClientType {
case policies.UserType:
permission, err := req.Type.Permission()
+3 -3
View File
@@ -7,16 +7,16 @@ import "github.com/spf13/cobra"
var cmdMessages = []cobra.Command{
{
Use: "send <channel_id.subtopic> <JSON_string> <client_secret>",
Use: "send <channel_id.subtopic> <JSON_string> <domain_id> <client_secret>",
Short: "Send messages",
Long: `Sends message on the channel`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 3 {
if len(args) != 4 {
logUsageCmd(*cmd, cmd.Use)
return
}
if err := sdk.SendMessage(cmd.Context(), args[0], args[1], args[2]); err != nil {
if err := sdk.SendMessage(cmd.Context(), args[0], args[1], args[2], args[3]); err != nil {
logErrorCmd(*cmd, err)
return
}
+4 -1
View File
@@ -37,6 +37,7 @@ func TestSendMesageCmd(t *testing.T) {
args: []string{
channel.ID,
message,
domainID,
client.Credentials.Secret,
},
logType: okLog,
@@ -47,6 +48,7 @@ func TestSendMesageCmd(t *testing.T) {
channel.ID,
message,
client.Credentials.Secret,
domainID,
extraArg,
},
logType: usageLog,
@@ -56,6 +58,7 @@ func TestSendMesageCmd(t *testing.T) {
args: []string{
channel.ID,
message,
domainID,
"invalid_secret",
},
sdkErr: errors.NewSDKErrorWithStatus(errors.Wrap(svcerr.ErrAuthentication, errors.Wrap(svcerr.ErrAuthorization, svcerr.ErrNotFound)), http.StatusBadRequest),
@@ -66,7 +69,7 @@ func TestSendMesageCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("SendMessage", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.sdkErr)
sdkCall := sdkMock.On("SendMessage", mock.Anything, tc.args[0], tc.args[1], tc.args[2], tc.args[3]).Return(tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{sendCmd}, tc.args...)...)
switch tc.logType {
+4 -3
View File
@@ -31,6 +31,7 @@ import (
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
pkgDomains "github.com/absmach/supermq/pkg/domains"
dconsumer "github.com/absmach/supermq/pkg/domains/events/consumer"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
gconsumer "github.com/absmach/supermq/pkg/groups/events/consumer"
@@ -225,7 +226,7 @@ func main() {
defer groupsHandler.Close()
logger.Info("Groups gRPC client successfully connected to groups gRPC server " + groupsHandler.Secure())
svc, psvc, err := newService(ctx, db, dbConfig, authz, policyEvaluator, policyService, cfg, tracer, clientsClient, groupsClient, logger)
svc, psvc, err := newService(ctx, db, dbConfig, authz, policyEvaluator, policyService, cfg, tracer, clientsClient, groupsClient, domAuthz, logger)
if err != nil {
logger.Error(fmt.Sprintf("failed to create services: %s", err))
exitCode = 1
@@ -298,7 +299,7 @@ func main() {
func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, authz smqauthz.Authorization,
pe policies.Evaluator, ps policies.Service, cfg config, tracer trace.Tracer, clientsClient grpcClientsV1.ClientsServiceClient,
groupsClient grpcGroupsV1.GroupsServiceClient, logger *slog.Logger,
groupsClient grpcGroupsV1.GroupsServiceClient, da pkgDomains.Authorization, logger *slog.Logger,
) (channels.Service, pChannels.Service, error) {
database := pg.NewDatabase(db, dbConfig, tracer)
repo := postgres.NewRepository(database)
@@ -335,7 +336,7 @@ func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, auth
}
svc = middleware.LoggingMiddleware(svc, logger)
psvc := pChannels.New(repo, pe, ps)
psvc := pChannels.New(repo, pe, ps, da)
return svc, psvc, err
}
+19 -10
View File
@@ -29,12 +29,12 @@ type Service interface {
// Key is used to authorize publisher.
Publish(ctx context.Context, key string, msg *messaging.Message) error
// Subscribes to channel with specified id, subtopic and adds subscription to
// Subscribes to channel with specified id, domainID, subtopic and adds subscription to
// service map of subscriptions under given ID.
Subscribe(ctx context.Context, key, chanID, subtopic string, c Client) error
Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c Client) error
// Unsubscribe method is used to stop observing resource.
Unsubscribe(ctx context.Context, key, chanID, subptopic, token string) error
Unsubscribe(ctx context.Context, key, domainID, chanID, subptopic, token string) error
// DisconnectHandler method is used to disconnected the client
DisconnectHandler(ctx context.Context, chanID, subptopic, token string) error
@@ -72,6 +72,7 @@ func (svc *adapterService) Publish(ctx context.Context, key string, msg *messagi
}
authzRes, err := svc.channels.Authorize(ctx, &grpcChannelsV1.AuthzReq{
DomainId: msg.GetDomain(),
ClientId: authnRes.GetId(),
ClientType: policies.ClientType,
Type: uint32(connections.Publish),
@@ -89,7 +90,7 @@ func (svc *adapterService) Publish(ctx context.Context, key string, msg *messagi
return svc.pubsub.Publish(ctx, msg.GetChannel(), msg)
}
func (svc *adapterService) Subscribe(ctx context.Context, key, chanID, subtopic string, c Client) error {
func (svc *adapterService) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c Client) error {
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{
ClientSecret: key,
})
@@ -102,6 +103,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, key, chanID, subtopic
clientID := authnRes.GetId()
authzRes, err := svc.channels.Authorize(ctx, &grpcChannelsV1.AuthzReq{
DomainId: domainID,
ClientId: clientID,
ClientType: policies.ClientType,
Type: uint32(connections.Subscribe),
@@ -119,7 +121,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, key, chanID, subtopic
subject = fmt.Sprintf("%s.%s", subject, subtopic)
}
authzc := newAuthzClient(clientID, chanID, subtopic, svc.channels, c)
authzc := newAuthzClient(clientID, domainID, chanID, subtopic, svc.channels, c)
subCfg := messaging.SubscriberConfig{
ID: c.Token(),
ClientID: clientID,
@@ -129,7 +131,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, key, chanID, subtopic
return svc.pubsub.Subscribe(ctx, subCfg)
}
func (svc *adapterService) Unsubscribe(ctx context.Context, key, chanID, subtopic, token string) error {
func (svc *adapterService) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) error {
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{
ClientSecret: key,
})
@@ -141,7 +143,7 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, key, chanID, subtopi
}
authzRes, err := svc.channels.Authorize(ctx, &grpcChannelsV1.AuthzReq{
DomainId: "",
DomainId: domainID,
ClientId: authnRes.GetId(),
ClientType: policies.ClientType,
Type: uint32(connections.Subscribe),
@@ -182,17 +184,24 @@ type authzClient interface {
type ac struct {
clientID string
channelID string
domainID string
subTopic string
channels grpcChannelsV1.ChannelsServiceClient
client Client
}
func newAuthzClient(clientID, channelID, subTopic string, channels grpcChannelsV1.ChannelsServiceClient, client Client) authzClient {
return ac{clientID, channelID, subTopic, channels, client}
func newAuthzClient(clientID, domainID, channelID, subTopic string, channels grpcChannelsV1.ChannelsServiceClient, client Client) authzClient {
return ac{clientID, channelID, domainID, subTopic, channels, client}
}
func (a ac) Handle(m *messaging.Message) error {
res, err := a.channels.Authorize(context.Background(), &grpcChannelsV1.AuthzReq{ClientId: a.clientID, ClientType: policies.ClientType, ChannelId: a.channelID, Type: uint32(connections.Subscribe)})
res, err := a.channels.Authorize(context.Background(), &grpcChannelsV1.AuthzReq{
ClientId: a.clientID,
ClientType: policies.ClientType,
ChannelId: a.channelID,
DomainId: a.domainID,
Type: uint32(connections.Subscribe),
})
if err != nil {
if disErr := a.Cancel(); disErr != nil {
return errors.Wrap(err, errors.Wrap(errFailedToDisconnectClient, disErr))
+7 -4
View File
@@ -33,6 +33,7 @@ func (lm *loggingMiddleware) Publish(ctx context.Context, key string, msg *messa
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("channel_id", msg.GetChannel()),
slog.String("domain_id", msg.GetDomain()),
}
if msg.GetSubtopic() != "" {
args = append(args, slog.String("subtopic", msg.GetSubtopic()))
@@ -50,11 +51,12 @@ func (lm *loggingMiddleware) Publish(ctx context.Context, key string, msg *messa
// Subscribe logs the subscribe request. It logs the channel ID, subtopic (if any) and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Subscribe(ctx context.Context, key, chanID, subtopic string, c coap.Client) (err error) {
func (lm *loggingMiddleware) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c coap.Client) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("channel_id", chanID),
slog.String("domain_id", domainID),
}
if subtopic != "" {
args = append(args, slog.String("subtopic", subtopic))
@@ -67,16 +69,17 @@ func (lm *loggingMiddleware) Subscribe(ctx context.Context, key, chanID, subtopi
lm.logger.Info("Subscribe completed successfully", args...)
}(time.Now())
return lm.svc.Subscribe(ctx, key, chanID, subtopic, c)
return lm.svc.Subscribe(ctx, key, domainID, chanID, subtopic, c)
}
// Unsubscribe logs the unsubscribe request. It logs the channel ID, subtopic (if any) and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, key, chanID, subtopic, token string) (err error) {
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("channel_id", chanID),
slog.String("domain_id", domainID),
}
if subtopic != "" {
args = append(args, slog.String("subtopic", subtopic))
@@ -89,7 +92,7 @@ func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, key, chanID, subto
lm.logger.Info("Unsubscribe completed successfully", args...)
}(time.Now())
return lm.svc.Unsubscribe(ctx, key, chanID, subtopic, token)
return lm.svc.Unsubscribe(ctx, key, domainID, chanID, subtopic, token)
}
// DisconnectHandler logs the disconnect handler. It logs the channel ID, subtopic (if any) and the time it took to complete the request.
+4 -4
View File
@@ -42,23 +42,23 @@ func (mm *metricsMiddleware) Publish(ctx context.Context, key string, msg *messa
}
// Subscribe instruments Subscribe method with metrics.
func (mm *metricsMiddleware) Subscribe(ctx context.Context, key, chanID, subtopic string, c coap.Client) error {
func (mm *metricsMiddleware) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c coap.Client) error {
defer func(begin time.Time) {
mm.counter.With("method", "subscribe").Add(1)
mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Subscribe(ctx, key, chanID, subtopic, c)
return mm.svc.Subscribe(ctx, key, domainID, chanID, subtopic, c)
}
// Unsubscribe instruments Unsubscribe method with metrics.
func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, key, chanID, subtopic, token string) error {
func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) error {
defer func(begin time.Time) {
mm.counter.With("method", "unsubscribe").Add(1)
mm.latency.With("method", "unsubscribe").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Unsubscribe(ctx, key, chanID, subtopic, token)
return mm.svc.Unsubscribe(ctx, key, domainID, chanID, subtopic, token)
}
// DisconnectHandler instruments DisconnectHandler method with metrics.
+8 -6
View File
@@ -33,11 +33,12 @@ const (
startObserve = 0 // observe option value that indicates start of observation
)
var channelPartRegExp = regexp.MustCompile(`^/c/([\w\-]+)/m(/[^?]*)?(\?.*)?$`)
var channelPartRegExp = regexp.MustCompile(`^/m/([\w\-]+)/c/([\w\-]+)(/[^?]*)?(\?.*)?$`)
const (
numGroups = 3 // entire expression + channel group + subtopic group
channelGroup = 2 // channel group is second in channel regexp
numGroups = 4 // entire expression+ domain group + channel group + subtopic group
domainGroup = 1 // domain group is first in channel regexp
channelGroup = 3 // channel group is third in channel regexp
)
var (
@@ -134,9 +135,9 @@ func handleGet(m *mux.Message, w mux.ResponseWriter, msg *messaging.Message, key
w.Conn().AddOnClose(func() {
_ = service.DisconnectHandler(context.Background(), msg.GetChannel(), msg.GetSubtopic(), c.Token())
})
return service.Subscribe(w.Conn().Context(), key, msg.GetChannel(), msg.GetSubtopic(), c)
return service.Subscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c)
}
return service.Unsubscribe(w.Conn().Context(), key, msg.GetChannel(), msg.GetSubtopic(), m.Token().String())
return service.Unsubscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), m.Token().String())
}
func decodeMessage(msg *mux.Message) (*messaging.Message, error) {
@@ -158,7 +159,8 @@ func decodeMessage(msg *mux.Message) (*messaging.Message, error) {
}
ret := &messaging.Message{
Protocol: protocol,
Channel: channelParts[1],
Domain: channelParts[domainGroup],
Channel: channelParts[2],
Subtopic: st,
Payload: []byte{},
Created: time.Now().UnixNano(),
+7 -5
View File
@@ -44,23 +44,25 @@ func (tm *tracingServiceMiddleware) Publish(ctx context.Context, key string, msg
}
// Subscribe traces a CoAP subscribe operation.
func (tm *tracingServiceMiddleware) Subscribe(ctx context.Context, key, chanID, subtopic string, c coap.Client) error {
func (tm *tracingServiceMiddleware) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c coap.Client) error {
ctx, span := tm.tracer.Start(ctx, subscribeOP, trace.WithAttributes(
attribute.String("channel_id", chanID),
attribute.String("domain_id", domainID),
attribute.String("subtopic", subtopic),
))
defer span.End()
return tm.svc.Subscribe(ctx, key, chanID, subtopic, c)
return tm.svc.Subscribe(ctx, key, domainID, chanID, subtopic, c)
}
// Unsubscribe traces a CoAP unsubscribe operation.
func (tm *tracingServiceMiddleware) Unsubscribe(ctx context.Context, key, chanID, subptopic, token string) error {
func (tm *tracingServiceMiddleware) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) error {
ctx, span := tm.tracer.Start(ctx, unsubscribeOP, trace.WithAttributes(
attribute.String("channel_id", chanID),
attribute.String("subtopic", subptopic),
attribute.String("domain_id", domainID),
attribute.String("subtopic", subtopic),
))
defer span.End()
return tm.svc.Unsubscribe(ctx, key, chanID, subptopic, token)
return tm.svc.Unsubscribe(ctx, key, domainID, chanID, subtopic, token)
}
// DisconnectHandler traces a CoAP disconnect operation.
+29 -4
View File
@@ -37,7 +37,11 @@ const (
invalidValue = "invalid"
)
var clientID = testsutil.GenerateUUID(&testing.T{})
var (
clientID = testsutil.GenerateUUID(&testing.T{})
chanID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
)
func newService(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) (session.Handler, *pubsub.PubSub) {
pub := new(pubsub.PubSub)
@@ -93,7 +97,6 @@ func TestPublish(t *testing.T) {
clients := new(climocks.ClientsServiceClient)
authn := new(authnMocks.Authentication)
channels := new(chmocks.ChannelsServiceClient)
chanID := "1"
ctSenmlJSON := "application/senml+json"
ctSenmlCBOR := "application/senml+cbor"
ctJSON := "application/json"
@@ -112,6 +115,7 @@ func TestPublish(t *testing.T) {
cases := []struct {
desc string
domainID string
chanID string
msg string
contentType string
@@ -126,6 +130,7 @@ func TestPublish(t *testing.T) {
}{
{
desc: "publish message successfully",
domainID: domainID,
chanID: chanID,
msg: msg,
contentType: ctSenmlJSON,
@@ -136,6 +141,7 @@ func TestPublish(t *testing.T) {
},
{
desc: "publish message with application/senml+cbor content-type",
domainID: domainID,
chanID: chanID,
msg: msgCBOR,
contentType: ctSenmlCBOR,
@@ -146,6 +152,7 @@ func TestPublish(t *testing.T) {
},
{
desc: "publish message with application/json content-type",
domainID: domainID,
chanID: chanID,
msg: msgJSON,
contentType: ctJSON,
@@ -156,6 +163,7 @@ func TestPublish(t *testing.T) {
},
{
desc: "publish message with empty key",
domainID: domainID,
chanID: chanID,
msg: msg,
contentType: ctSenmlJSON,
@@ -164,6 +172,7 @@ func TestPublish(t *testing.T) {
},
{
desc: "publish message with basic auth",
domainID: domainID,
chanID: chanID,
msg: msg,
contentType: ctSenmlJSON,
@@ -175,6 +184,7 @@ func TestPublish(t *testing.T) {
},
{
desc: "publish message with invalid key",
domainID: domainID,
chanID: chanID,
msg: msg,
contentType: ctSenmlJSON,
@@ -184,6 +194,7 @@ func TestPublish(t *testing.T) {
},
{
desc: "publish message with invalid basic auth",
domainID: domainID,
chanID: chanID,
msg: msg,
contentType: ctSenmlJSON,
@@ -194,6 +205,7 @@ func TestPublish(t *testing.T) {
},
{
desc: "publish message without content type",
domainID: domainID,
chanID: chanID,
msg: msg,
contentType: "",
@@ -203,7 +215,8 @@ func TestPublish(t *testing.T) {
authzRes: &grpcChannelsV1.AuthzRes{Authorized: true},
},
{
desc: "publish message to invalid channel",
desc: "publish message to empty channel",
domainID: domainID,
chanID: "",
msg: msg,
contentType: ctSenmlJSON,
@@ -212,12 +225,24 @@ func TestPublish(t *testing.T) {
authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authzRes: &grpcChannelsV1.AuthzRes{Authorized: false},
},
{
desc: "publish message with invalid domain ID",
domainID: invalidValue,
chanID: chanID,
msg: msg,
contentType: ctSenmlJSON,
key: clientKey,
status: http.StatusUnauthorized,
authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authzRes: &grpcChannelsV1.AuthzRes{Authorized: false},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: tc.key}).Return(tc.authnRes, tc.authnErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
DomainId: tc.domainID,
ChannelId: tc.chanID,
ClientId: clientID,
ClientType: policies.ClientType,
@@ -227,7 +252,7 @@ func TestPublish(t *testing.T) {
req := testRequest{
client: ts.Client(),
method: http.MethodPost,
url: fmt.Sprintf("%s/c/%s/m", ts.URL, tc.chanID),
url: fmt.Sprintf("%s/m/%s/c/%s", ts.URL, tc.domainID, tc.chanID),
contentType: tc.contentType,
token: tc.key,
body: strings.NewReader(tc.msg),
+2 -2
View File
@@ -33,14 +33,14 @@ func MakeHandler(logger *slog.Logger, instanceID string) http.Handler {
}
r := chi.NewRouter()
r.Post("/c/{chanID}/m", otelhttp.NewHandler(kithttp.NewServer(
r.Post("/m/{domainID}/c/{chanID}", otelhttp.NewHandler(kithttp.NewServer(
sendMessageEndpoint(),
decodeRequest,
api.EncodeResponse,
opts...,
), "publish").ServeHTTP)
r.Post("/c/{chanID}/m/*", otelhttp.NewHandler(kithttp.NewServer(
r.Post("/m/{domainID}/c/{chanID}/*", otelhttp.NewHandler(kithttp.NewServer(
sendMessageEndpoint(),
decodeRequest,
api.EncodeResponse,
+13 -10
View File
@@ -55,7 +55,7 @@ var (
errFailedParseSubtopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to parse subtopic"))
)
var channelRegExp = regexp.MustCompile(`^\/?c\/([\w\-]+)\/m(\/[^?]*)?(\?.*)?$`)
var channelRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
// Event implements events.Event interface.
type handler struct {
@@ -153,13 +153,14 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
chanID, subtopic, err := parseTopic(*topic)
domainID, chanID, subtopic, err := parseTopic(*topic)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, err)
}
msg := messaging.Message{
Protocol: protocol,
Domain: domainID,
Channel: chanID,
Subtopic: subtopic,
Payload: *payload,
@@ -167,6 +168,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
}
ar := &grpcChannelsV1.AuthzReq{
DomainId: domainID,
ClientId: clientID,
ClientType: clientType,
ChannelId: msg.Channel,
@@ -208,23 +210,24 @@ func (h *handler) Disconnect(ctx context.Context) error {
return nil
}
func parseTopic(topic string) (string, string, error) {
func parseTopic(topic string) (string, string, string, error) {
// Topics are in the format:
// c/<channel_id>/m/<subtopic>/.../ct/<content_type>
// m/<domain_id>/c/<channel_id>/<subtopic>/.../ct/<content_type>
channelParts := channelRegExp.FindStringSubmatch(topic)
if len(channelParts) < 2 {
return "", "", errors.Wrap(errFailedPublish, errMalformedTopic)
if len(channelParts) < 3 {
return "", "", "", errors.Wrap(errFailedPublish, errMalformedTopic)
}
chanID := channelParts[1]
subtopic := channelParts[2]
domainID := channelParts[1]
chanID := channelParts[2]
subtopic := channelParts[3]
subtopic, err := parseSubtopic(subtopic)
if err != nil {
return "", "", errors.Wrap(errFailedParseSubtopic, err)
return "", "", "", errors.Wrap(errFailedParseSubtopic, err)
}
return chanID, subtopic, nil
return domainID, chanID, subtopic, nil
}
func parseSubtopic(subtopic string) (string, error) {
+6 -5
View File
@@ -35,14 +35,15 @@ const (
chanID = "123e4567-e89b-12d3-a456-000000000001"
invalidID = "invalidID"
invalidValue = "invalidValue"
invalidChannelIDTopic = "c/**/m"
invalidChannelIDTopic = "m/**/c"
)
var (
topicMsg = "c/%s/m"
subtopicMsg = "c/%s/m/subtopic"
topic = fmt.Sprintf(topicMsg, chanID)
subtopic = fmt.Sprintf(subtopicMsg, chanID)
domainID = testsutil.GenerateUUID(&testing.T{})
topicMsg = "/m/%s/c/%s"
subtopicMsg = "/m/%s/c/%s/subtopic"
topic = fmt.Sprintf(topicMsg, domainID, chanID)
subtopic = fmt.Sprintf(subtopicMsg, domainID, chanID)
invalidTopic = invalidValue
payload = []byte("[{'n':'test-name', 'v': 1.2}]")
sessionClient = session.Session{
+4 -4
View File
@@ -25,7 +25,7 @@ const (
var (
errFailedSession = errors.New("failed to obtain session from context")
errMalformedTopic = errors.New("malformed topic")
channelRegExp = regexp.MustCompile(`^\/?c\/([\w\-]+)\/m(\/[^?]*)?(\?.*)?$`)
channelRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
)
// EventStore is a struct used to store event streams in Redis.
@@ -142,12 +142,12 @@ func (es *eventStore) Disconnect(ctx context.Context) error {
func parseTopic(topic string) (string, string, error) {
channelParts := channelRegExp.FindStringSubmatch(topic)
if len(channelParts) < 2 {
if len(channelParts) < 3 {
return "", "", errMalformedTopic
}
chanID := channelParts[1]
subtopic := channelParts[2]
chanID := channelParts[2]
subtopic := channelParts[3]
if subtopic == "" {
return subtopic, chanID, nil
+1 -1
View File
@@ -49,7 +49,7 @@ func handle(ctx context.Context, pub messaging.Publisher, logger *slog.Logger) h
}
// Use concatenation instead of fmt.Sprintf for the
// sake of simplicity and performance.
topic := "c/" + msg.GetChannel() + "/m"
topic := "m/" + msg.GetDomain() + "/c/" + msg.GetChannel()
if msg.GetSubtopic() != "" {
topic = topic + "/" + strings.ReplaceAll(msg.GetSubtopic(), ".", "/")
}
+13 -8
View File
@@ -57,7 +57,7 @@ var (
var (
errInvalidUserId = errors.New("invalid user id")
channelRegExp = regexp.MustCompile(`^\/?c\/([\w\-]+)\/m(\/[^?]*)?(\?.*)?$`)
channelRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
)
// Event implements events.Event interface.
@@ -158,16 +158,18 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return errors.Wrap(ErrFailedPublish, ErrClientNotInitialized)
}
h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic))
// Topics are in the format:
// c/<channel_id>/m/<subtopic>/.../ct/<content_type>
// m/<domain_id>/c/<channel_id>/<subtopic>/.../ct/<content_type>
channelParts := channelRegExp.FindStringSubmatch(*topic)
if len(channelParts) < 2 {
if len(channelParts) < 3 {
return errors.Wrap(ErrFailedPublish, ErrMalformedTopic)
}
chanID := channelParts[1]
subtopic := channelParts[2]
domainID := channelParts[1]
chanID := channelParts[2]
subtopic := channelParts[3]
subtopic, err := parseSubtopic(subtopic)
if err != nil {
@@ -176,6 +178,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
msg := messaging.Message{
Protocol: protocol,
Domain: domainID,
Channel: chanID,
Subtopic: subtopic,
Publisher: s.Username,
@@ -225,23 +228,25 @@ func (h *handler) Disconnect(ctx context.Context) error {
func (h *handler) authAccess(ctx context.Context, clientID, topic string, msgType connections.ConnType) error {
// Topics are in the format:
// c/<channel_id>/m/<subtopic>/.../ct/<content_type>
// m/<domain_id>/c/<channel_id>/<subtopic>/.../ct/<content_type>
if !channelRegExp.MatchString(topic) {
return ErrMalformedTopic
}
channelParts := channelRegExp.FindStringSubmatch(topic)
if len(channelParts) < 1 {
if len(channelParts) < 3 {
return ErrMalformedTopic
}
chanID := channelParts[1]
domainID := channelParts[1]
chanID := channelParts[2]
ar := &grpcChannelsV1.AuthzReq{
Type: uint32(msgType),
ClientId: clientID,
ClientType: policies.ClientType,
ChannelId: chanID,
DomainId: domainID,
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
+7 -4
View File
@@ -36,17 +36,18 @@ const (
clientID = "clientID"
clientID1 = "clientID1"
subtopic = "testSubtopic"
invalidChannelIDTopic = "c/**/m"
invalidChannelIDTopic = "m/**/c"
)
var (
topicMsg = "c/%s/m"
topic = fmt.Sprintf(topicMsg, chanID)
domainID = testsutil.GenerateUUID(&testing.T{})
topicMsg = "/m/%s/c/%s"
topic = fmt.Sprintf(topicMsg, domainID, chanID)
invalidTopic = invalidValue
payload = []byte("[{'n':'test-name', 'v': 1.2}]")
topics = []string{topic}
invalidTopics = []string{invalidValue}
invalidChanIDTopics = []string{fmt.Sprintf(topicMsg, invalidValue)}
invalidChanIDTopics = []string{fmt.Sprintf(topicMsg, domainID, invalidValue)}
// Test log messages for cases the handler does not provide a return value.
logBuffer = bytes.Buffer{}
sessionClient = session.Session{
@@ -213,6 +214,7 @@ func TestAuthPublish(t *testing.T) {
ctx = session.NewContext(ctx, tc.session)
}
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
DomainId: domainID,
ChannelId: chanID,
ClientId: clientID,
ClientType: policies.ClientType,
@@ -288,6 +290,7 @@ func TestAuthSubscribe(t *testing.T) {
ctx = session.NewContext(ctx, tc.session)
}
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
DomainId: domainID,
ChannelId: tc.channelID,
ClientId: clientID1,
ClientType: policies.ClientType,
+21 -12
View File
@@ -28,11 +28,12 @@ const (
type Message struct {
state protoimpl.MessageState `protogen:"open.v1"`
Channel string `protobuf:"bytes,1,opt,name=channel,proto3" json:"channel,omitempty"`
Subtopic string `protobuf:"bytes,2,opt,name=subtopic,proto3" json:"subtopic,omitempty"`
Publisher string `protobuf:"bytes,3,opt,name=publisher,proto3" json:"publisher,omitempty"`
Protocol string `protobuf:"bytes,4,opt,name=protocol,proto3" json:"protocol,omitempty"`
Payload []byte `protobuf:"bytes,5,opt,name=payload,proto3" json:"payload,omitempty"`
Created int64 `protobuf:"varint,6,opt,name=created,proto3" json:"created,omitempty"` // Unix timestamp in nanoseconds
Domain string `protobuf:"bytes,2,opt,name=domain,proto3" json:"domain,omitempty"`
Subtopic string `protobuf:"bytes,3,opt,name=subtopic,proto3" json:"subtopic,omitempty"`
Publisher string `protobuf:"bytes,4,opt,name=publisher,proto3" json:"publisher,omitempty"`
Protocol string `protobuf:"bytes,5,opt,name=protocol,proto3" json:"protocol,omitempty"`
Payload []byte `protobuf:"bytes,6,opt,name=payload,proto3" json:"payload,omitempty"`
Created int64 `protobuf:"varint,7,opt,name=created,proto3" json:"created,omitempty"` // Unix timestamp in nanoseconds
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -74,6 +75,13 @@ func (x *Message) GetChannel() string {
return ""
}
func (x *Message) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
func (x *Message) GetSubtopic() string {
if x != nil {
return x.Subtopic
@@ -113,14 +121,15 @@ var File_pkg_messaging_message_proto protoreflect.FileDescriptor
const file_pkg_messaging_message_proto_rawDesc = "" +
"\n" +
"\x1bpkg/messaging/message.proto\x12\tmessaging\"\xad\x01\n" +
"\x1bpkg/messaging/message.proto\x12\tmessaging\"\xc5\x01\n" +
"\aMessage\x12\x18\n" +
"\achannel\x18\x01 \x01(\tR\achannel\x12\x1a\n" +
"\bsubtopic\x18\x02 \x01(\tR\bsubtopic\x12\x1c\n" +
"\tpublisher\x18\x03 \x01(\tR\tpublisher\x12\x1a\n" +
"\bprotocol\x18\x04 \x01(\tR\bprotocol\x12\x18\n" +
"\apayload\x18\x05 \x01(\fR\apayload\x12\x18\n" +
"\acreated\x18\x06 \x01(\x03R\acreatedB\rZ\v./messagingb\x06proto3"
"\achannel\x18\x01 \x01(\tR\achannel\x12\x16\n" +
"\x06domain\x18\x02 \x01(\tR\x06domain\x12\x1a\n" +
"\bsubtopic\x18\x03 \x01(\tR\bsubtopic\x12\x1c\n" +
"\tpublisher\x18\x04 \x01(\tR\tpublisher\x12\x1a\n" +
"\bprotocol\x18\x05 \x01(\tR\bprotocol\x12\x18\n" +
"\apayload\x18\x06 \x01(\fR\apayload\x12\x18\n" +
"\acreated\x18\a \x01(\x03R\acreatedB\rZ\v./messagingb\x06proto3"
var (
file_pkg_messaging_message_proto_rawDescOnce sync.Once
+6 -5
View File
@@ -9,9 +9,10 @@ option go_package = "./messaging";
// Message represents a message emitted by the SuperMQ adapters layer.
message Message {
string channel = 1;
string subtopic = 2;
string publisher = 3;
string protocol = 4;
bytes payload = 5;
int64 created = 6; // Unix timestamp in nanoseconds
string domain = 2;
string subtopic = 3;
string publisher = 4;
string protocol = 5;
bytes payload = 6;
int64 created = 7; // Unix timestamp in nanoseconds
}
+2 -2
View File
@@ -15,7 +15,7 @@ import (
const channelParts = 2
func (sdk mgSDK) SendMessage(ctx context.Context, chanName, msg, key string) errors.SDKError {
func (sdk mgSDK) SendMessage(ctx context.Context, chanName, msg, domainID, key string) errors.SDKError {
chanNameParts := strings.SplitN(chanName, ".", channelParts)
chanID := chanNameParts[0]
subtopicPart := ""
@@ -23,7 +23,7 @@ func (sdk mgSDK) SendMessage(ctx context.Context, chanName, msg, key string) err
subtopicPart = fmt.Sprintf("/%s", strings.ReplaceAll(chanNameParts[1], ".", "/"))
}
reqURL := fmt.Sprintf("%s/c/%s/m%s", sdk.httpAdapterURL, chanID, subtopicPart)
reqURL := fmt.Sprintf("%s/m/%s/c/%s%s", sdk.httpAdapterURL, domainID, chanID, subtopicPart)
_, _, err := sdk.processRequest(ctx, http.MethodPost, reqURL, ClientPrefix+key, []byte(msg), nil, http.StatusAccepted)
+20 -1
View File
@@ -63,6 +63,7 @@ func TestSendMessage(t *testing.T) {
msg := `[{"n":"current","t":-1,"v":1.6}]`
clientKey := "clientKey"
channelID := "channelID"
domainID := "domainID"
sdkConf := sdk.Config{
HTTPAdapterURL: ts.URL,
@@ -75,6 +76,7 @@ func TestSendMessage(t *testing.T) {
cases := []struct {
desc string
chanName string
domainID string
msg string
clientKey string
authRes *grpcClientsV1.AuthnRes
@@ -85,6 +87,7 @@ func TestSendMessage(t *testing.T) {
{
desc: "publish message successfully",
chanName: channelID,
domainID: domainID,
msg: msg,
clientKey: clientKey,
authRes: &grpcClientsV1.AuthnRes{Authenticated: true, Id: ""},
@@ -95,6 +98,7 @@ func TestSendMessage(t *testing.T) {
{
desc: "publish message with empty client key",
chanName: channelID,
domainID: domainID,
msg: msg,
clientKey: "",
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
@@ -105,6 +109,7 @@ func TestSendMessage(t *testing.T) {
{
desc: "publish message with invalid client key",
chanName: channelID,
domainID: domainID,
msg: msg,
clientKey: "invalid",
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
@@ -115,6 +120,7 @@ func TestSendMessage(t *testing.T) {
{
desc: "publish message with invalid channel ID",
chanName: wrongID,
domainID: domainID,
msg: msg,
clientKey: clientKey,
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
@@ -125,6 +131,7 @@ func TestSendMessage(t *testing.T) {
{
desc: "publish message with empty message body",
chanName: channelID,
domainID: domainID,
msg: "",
clientKey: clientKey,
authRes: &grpcClientsV1.AuthnRes{Authenticated: true, Id: ""},
@@ -135,6 +142,7 @@ func TestSendMessage(t *testing.T) {
{
desc: "publish message with channel subtopic",
chanName: channelID + ".subtopic",
domainID: domainID,
msg: msg,
clientKey: clientKey,
authRes: &grpcClientsV1.AuthnRes{Authenticated: true, Id: ""},
@@ -142,13 +150,24 @@ func TestSendMessage(t *testing.T) {
svcErr: nil,
err: nil,
},
{
desc: "publish message with invalid domain ID",
chanName: channelID,
domainID: wrongID,
msg: msg,
clientKey: clientKey,
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
authErr: svcerr.ErrAuthentication,
svcErr: svcerr.ErrAuthentication,
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authzCall := clientsGRPCClient.On("Authenticate", mock.Anything, mock.Anything).Return(tc.authRes, tc.authErr)
authnCall := channelsGRPCClient.On("Authorize", mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil)
svcCall := pub.On("Publish", mock.Anything, channelID, mock.Anything).Return(tc.svcErr)
err := mgsdk.SendMessage(context.Background(), tc.chanName, tc.msg, tc.clientKey)
err := mgsdk.SendMessage(context.Background(), tc.chanName, tc.msg, tc.domainID, tc.clientKey)
assert.Equal(t, tc.err, err)
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "Publish", mock.Anything, channelID, mock.Anything)
+10 -9
View File
@@ -5644,16 +5644,16 @@ func (_c *SDK_SendInvitation_Call) RunAndReturn(run func(ctx context.Context, in
}
// SendMessage provides a mock function for the type SDK
func (_mock *SDK) SendMessage(ctx context.Context, chanID string, msg string, key string) errors.SDKError {
ret := _mock.Called(ctx, chanID, msg, key)
func (_mock *SDK) SendMessage(ctx context.Context, chanID string, msg string, domainID string, key string) errors.SDKError {
ret := _mock.Called(ctx, chanID, msg, domainID, key)
if len(ret) == 0 {
panic("no return value specified for SendMessage")
}
var r0 errors.SDKError
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok {
r0 = returnFunc(ctx, chanID, msg, key)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) errors.SDKError); ok {
r0 = returnFunc(ctx, chanID, msg, domainID, key)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(errors.SDKError)
@@ -5671,14 +5671,15 @@ type SDK_SendMessage_Call struct {
// - ctx
// - chanID
// - msg
// - domainID
// - key
func (_e *SDK_Expecter) SendMessage(ctx interface{}, chanID interface{}, msg interface{}, key interface{}) *SDK_SendMessage_Call {
return &SDK_SendMessage_Call{Call: _e.mock.On("SendMessage", ctx, chanID, msg, key)}
func (_e *SDK_Expecter) SendMessage(ctx interface{}, chanID interface{}, msg interface{}, domainID interface{}, key interface{}) *SDK_SendMessage_Call {
return &SDK_SendMessage_Call{Call: _e.mock.On("SendMessage", ctx, chanID, msg, domainID, key)}
}
func (_c *SDK_SendMessage_Call) Run(run func(ctx context.Context, chanID string, msg string, key string)) *SDK_SendMessage_Call {
func (_c *SDK_SendMessage_Call) Run(run func(ctx context.Context, chanID string, msg string, domainID string, key string)) *SDK_SendMessage_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string))
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string))
})
return _c
}
@@ -5688,7 +5689,7 @@ func (_c *SDK_SendMessage_Call) Return(sDKError errors.SDKError) *SDK_SendMessag
return _c
}
func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, chanID string, msg string, key string) errors.SDKError) *SDK_SendMessage_Call {
func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, chanID string, msg string, domainID string, key string) errors.SDKError) *SDK_SendMessage_Call {
_c.Call.Return(run)
return _c
}
+2 -2
View File
@@ -1015,9 +1015,9 @@ type SDK interface {
//
// example:
// msg := '[{"bn":"some-base-name:","bt":1.276020076001e+09, "bu":"A","bver":5, "n":"voltage","u":"V","v":120.1}, {"n":"current","t":-5,"v":1.2}, {"n":"current","t":-4,"v":1.3}]'
// err := sdk.SendMessage("channelID", msg, "clientSecret")
// err := sdk.SendMessage("channelID", msg, "domainID", "clientSecret")
// fmt.Println(err)
SendMessage(ctx context.Context, chanID, msg, key string) errors.SDKError
SendMessage(ctx context.Context, chanID, msg, domainID, key string) errors.SDKError
// SetContentType sets message content type.
//
+8 -6
View File
@@ -36,9 +36,10 @@ var (
// Service specifies web socket service API.
type Service interface {
// Subscribe subscribes message from the broker using the clientKey for authorization,
// and the channelID for subscription. Subtopic is optional.
// the channelID for subscription and domainID specifies the domain for authorization.
// Subtopic is optional.
// If the subscription is successful, nil is returned otherwise error is returned.
Subscribe(ctx context.Context, clientKey, chanID, subtopic string, client *Client) error
Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, client *Client) error
}
var _ Service = (*adapterService)(nil)
@@ -58,12 +59,12 @@ func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.Cha
}
}
func (svc *adapterService) Subscribe(ctx context.Context, clientKey, chanID, subtopic string, c *Client) error {
if chanID == "" || clientKey == "" {
func (svc *adapterService) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, c *Client) error {
if chanID == "" || clientKey == "" || domainID == "" {
return svcerr.ErrAuthentication
}
clientID, err := svc.authorize(ctx, clientKey, chanID, connections.Subscribe)
clientID, err := svc.authorize(ctx, clientKey, domainID, chanID, connections.Subscribe)
if err != nil {
return svcerr.ErrAuthorization
}
@@ -90,7 +91,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, clientKey, chanID, sub
// authorize checks if the clientKey is authorized to access the channel
// and returns the clientID if it is.
func (svc *adapterService) authorize(ctx context.Context, clientKey, chanID string, msgType connections.ConnType) (string, error) {
func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, chanID string, msgType connections.ConnType) (string, error) {
authnReq := &grpcClientsV1.AuthnReq{
ClientSecret: clientKey,
}
@@ -110,6 +111,7 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, chanID stri
ClientId: authnRes.GetId(),
Type: uint32(msgType),
ChannelId: chanID,
DomainId: domainID,
}
authzRes, err := svc.channels.Authorize(ctx, authzReq)
if err != nil {
+18 -3
View File
@@ -35,14 +35,16 @@ const (
)
var (
msg = messaging.Message{
domainID = testsutil.GenerateUUID(&testing.T{})
clientID = testsutil.GenerateUUID(&testing.T{})
msg = messaging.Message{
Channel: chanID,
Domain: domainID,
Publisher: id,
Subtopic: "",
Protocol: protocol,
Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`),
}
clientID = testsutil.GenerateUUID(&testing.T{})
)
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient) {
@@ -62,6 +64,7 @@ func TestSubscribe(t *testing.T) {
desc string
clientKey string
chanID string
domainID string
subtopic string
authNRes *grpcClientsV1.AuthnRes
authNErr error
@@ -74,6 +77,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with valid clientKey, chanID, subtopic",
clientKey: clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
@@ -83,6 +87,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe again to channel with valid clientKey, chanID, subtopic",
clientKey: clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
@@ -92,6 +97,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with subscribe set to fail",
clientKey: clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
subErr: ws.ErrFailedSubscription,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
@@ -102,6 +108,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with invalid clientKey",
clientKey: invalidKey,
chanID: invalidID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
authNErr: svcerr.ErrAuthentication,
@@ -111,6 +118,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with empty channel",
clientKey: clientKey,
chanID: "",
domainID: domainID,
subtopic: subTopic,
err: svcerr.ErrAuthentication,
},
@@ -118,6 +126,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with empty clientKey",
clientKey: "",
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
err: svcerr.ErrAuthentication,
},
@@ -125,6 +134,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with empty clientKey and empty channel",
clientKey: "",
chanID: "",
domainID: domainID,
subtopic: subTopic,
err: svcerr.ErrAuthentication,
},
@@ -132,6 +142,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with invalid channel",
clientKey: clientKey,
chanID: invalidID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
@@ -142,6 +153,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with failed authentication",
clientKey: clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
err: svcerr.ErrAuthorization,
@@ -150,6 +162,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with failed authorization",
clientKey: clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
@@ -159,6 +172,7 @@ func TestSubscribe(t *testing.T) {
desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic",
clientKey: "Client " + clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
@@ -183,9 +197,10 @@ func TestSubscribe(t *testing.T) {
ClientId: tc.authNRes.GetId(),
Type: uint32(connections.Subscribe),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
repocall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
err := svc.Subscribe(context.Background(), tc.clientKey, tc.chanID, tc.subtopic, c)
err := svc.Subscribe(context.Background(), tc.clientKey, tc.domainID, tc.chanID, tc.subtopic, c)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repocall.Unset()
clientsCall.Unset()
+23 -10
View File
@@ -18,6 +18,7 @@ import (
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
chmocks "github.com/absmach/supermq/channels/mocks"
climocks "github.com/absmach/supermq/clients/mocks"
"github.com/absmach/supermq/internal/testsutil"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnMocks "github.com/absmach/supermq/pkg/authn/mocks"
@@ -31,14 +32,16 @@ import (
)
const (
chanID = "30315311-56ba-484d-b500-c1e08305511f"
id = "1"
clientKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529"
protocol = "ws"
instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
)
var msg = []byte(`[{"n":"current","t":-1,"v":1.6}]`)
var (
msg = []byte(`[{"n":"current","t":-1,"v":1.6}]`)
domainID = testsutil.GenerateUUID(&testing.T{})
)
func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) (ws.Service, *mocks.PubSub) {
pubsub := new(mocks.PubSub)
@@ -59,15 +62,15 @@ func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*ht
return httptest.NewServer(http.HandlerFunc(mp.Handler)), nil
}
func makeURL(tsURL, chanID, subtopic, clientKey string, header bool) (string, error) {
func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (string, error) {
u, _ := url.Parse(tsURL)
u.Scheme = protocol
if chanID == "0" || chanID == "" {
if header {
return fmt.Sprintf("%s/c/%s/m", u, chanID), fmt.Errorf("invalid channel id")
return fmt.Sprintf("%s/m/%s/c/%s", u, domainID, chanID), fmt.Errorf("invalid channel id")
}
return fmt.Sprintf("%s/c/%s/m?authorization=%s", u, chanID, clientKey), fmt.Errorf("invalid channel id")
return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, clientKey), fmt.Errorf("invalid channel id")
}
subtopicPart := ""
@@ -75,19 +78,19 @@ func makeURL(tsURL, chanID, subtopic, clientKey string, header bool) (string, er
subtopicPart = fmt.Sprintf("/%s", subtopic)
}
if header {
return fmt.Sprintf("%s/c/%s/m%s", u, chanID, subtopicPart), nil
return fmt.Sprintf("%s/m/%s/c/%s%s", u, domainID, chanID, subtopicPart), nil
}
return fmt.Sprintf("%s/c/%s/m%s?authorization=%s", u, chanID, subtopicPart, clientKey), nil
return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, clientKey), nil
}
func handshake(tsURL, chanID, subtopic, clientKey string, addHeader bool) (*websocket.Conn, *http.Response, error) {
func handshake(tsURL, domainID, chanID, subtopic, clientKey string, addHeader bool) (*websocket.Conn, *http.Response, error) {
header := http.Header{}
if addHeader {
header.Add("Authorization", clientKey)
}
turl, _ := makeURL(tsURL, chanID, subtopic, clientKey, addHeader)
turl, _ := makeURL(tsURL, domainID, chanID, subtopic, clientKey, addHeader)
conn, res, errRet := websocket.DefaultDialer.Dial(turl, header)
return conn, res, errRet
@@ -112,6 +115,7 @@ func TestHandshake(t *testing.T) {
cases := []struct {
desc string
domainID string
chanID string
subtopic string
header bool
@@ -122,6 +126,7 @@ func TestHandshake(t *testing.T) {
}{
{
desc: "connect and send message",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
@@ -131,6 +136,7 @@ func TestHandshake(t *testing.T) {
},
{
desc: "connect and send message with clientKey as query parameter",
domainID: domainID,
chanID: id,
subtopic: "",
header: false,
@@ -140,6 +146,7 @@ func TestHandshake(t *testing.T) {
},
{
desc: "connect and send message that cannot be published",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
@@ -149,6 +156,7 @@ func TestHandshake(t *testing.T) {
},
{
desc: "connect and send message to subtopic",
domainID: domainID,
chanID: id,
subtopic: "subtopic",
header: true,
@@ -158,6 +166,7 @@ func TestHandshake(t *testing.T) {
},
{
desc: "connect and send message to nested subtopic",
domainID: domainID,
chanID: id,
subtopic: "subtopic/nested",
header: true,
@@ -167,6 +176,7 @@ func TestHandshake(t *testing.T) {
},
{
desc: "connect and send message to all subtopics",
domainID: domainID,
chanID: id,
subtopic: ">",
header: true,
@@ -176,6 +186,7 @@ func TestHandshake(t *testing.T) {
},
{
desc: "connect to empty channel",
domainID: domainID,
chanID: "",
subtopic: "",
header: true,
@@ -185,6 +196,7 @@ func TestHandshake(t *testing.T) {
},
{
desc: "connect with empty clientKey",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
@@ -194,6 +206,7 @@ func TestHandshake(t *testing.T) {
},
{
desc: "connect and send message to subtopic with invalid name",
domainID: domainID,
chanID: id,
subtopic: "sub/a*b/topic",
header: true,
@@ -205,7 +218,7 @@ func TestHandshake(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
conn, res, err := handshake(ts.URL, tc.chanID, tc.subtopic, tc.clientKey, tc.header)
conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.clientKey, tc.header)
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode))
if tc.status == http.StatusSwitchingProtocols {
+6 -4
View File
@@ -16,7 +16,7 @@ import (
"github.com/go-chi/chi/v5"
)
var channelPartRegExp = regexp.MustCompile(`^/c/([\w\-]+)/m(/[^?]*)?(\?.*)?$`)
var channelPartRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
func handshake(ctx context.Context, svc ws.Service) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
@@ -33,7 +33,7 @@ func handshake(ctx context.Context, svc ws.Service) http.HandlerFunc {
req.conn = conn
client := ws.NewClient(conn)
if err := svc.Subscribe(ctx, req.clientKey, req.chanID, req.subtopic, client); err != nil {
if err := svc.Subscribe(ctx, req.clientKey, req.domainID, req.chanID, req.subtopic, client); err != nil {
req.conn.Close()
return
}
@@ -53,20 +53,22 @@ func decodeRequest(r *http.Request) (connReq, error) {
authKey = authKeys[0]
}
domainID := chi.URLParam(r, "domainID")
chanID := chi.URLParam(r, "chanID")
req := connReq{
clientKey: authKey,
chanID: chanID,
domainID: domainID,
}
channelParts := channelPartRegExp.FindStringSubmatch(r.RequestURI)
if len(channelParts) < 2 {
if len(channelParts) < 3 {
logger.Warn("Empty channel id or malformed url")
return connReq{}, errors.ErrMalformedEntity
}
subtopic, err := parseSubTopic(channelParts[2])
subtopic, err := parseSubTopic(channelParts[3])
if err != nil {
return connReq{}, err
}
+4 -3
View File
@@ -25,22 +25,23 @@ func LoggingMiddleware(svc ws.Service, logger *slog.Logger) ws.Service {
// Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Subscribe(ctx context.Context, clientKey, chanID, subtopic string, c *ws.Client) (err error) {
func (lm *loggingMiddleware) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, c *ws.Client) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("channel_id", chanID),
slog.String("domain_id", domainID),
}
if subtopic != "" {
args = append(args, "subtopic", subtopic)
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Subscibe failed", args...)
lm.logger.Warn("Subscribe failed", args...)
return
}
lm.logger.Info("Subscribe completed successfully", args...)
}(time.Now())
return lm.svc.Subscribe(ctx, clientKey, chanID, subtopic, c)
return lm.svc.Subscribe(ctx, clientKey, domainID, chanID, subtopic, c)
}
+2 -2
View File
@@ -31,11 +31,11 @@ func MetricsMiddleware(svc ws.Service, counter metrics.Counter, latency metrics.
}
// Subscribe instruments Subscribe method with metrics.
func (mm *metricsMiddleware) Subscribe(ctx context.Context, clientKey, chanID, subtopic string, c *ws.Client) error {
func (mm *metricsMiddleware) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, c *ws.Client) error {
defer func(begin time.Time) {
mm.counter.With("method", "subscribe").Add(1)
mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Subscribe(ctx, clientKey, chanID, subtopic, c)
return mm.svc.Subscribe(ctx, clientKey, domainID, chanID, subtopic, c)
}
+1
View File
@@ -8,6 +8,7 @@ import "github.com/gorilla/websocket"
type connReq struct {
clientKey string
chanID string
domainID string
subtopic string
conn *websocket.Conn
}
+2 -2
View File
@@ -40,8 +40,8 @@ func MakeHandler(ctx context.Context, svc ws.Service, l *slog.Logger, instanceID
logger = l
mux := chi.NewRouter()
mux.Get("/c/{chanID}/m", handshake(ctx, svc))
mux.Get("/c/{chanID}/m/*", handshake(ctx, svc))
mux.Get("/m/{domainID}/c/{chanID}", handshake(ctx, svc))
mux.Get("/m/{domainID}/c/{chanID}/*", handshake(ctx, svc))
mux.Get("/health", supermq.Health(service, instanceID))
mux.Handle("/metrics", promhttp.Handler())
+11 -7
View File
@@ -50,7 +50,7 @@ var (
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
)
var channelRegExp = regexp.MustCompile(`^\/?c\/([\w\-]+)\/m(\/[^?]*)?(\?.*)?$`)
var channelRegExp = regexp.MustCompile(`^\/?m\/([\w\-]+)\/c\/([\w\-]+)(\/[^?]*)?(\?.*)?$`)
// Event implements events.Event interface.
type handler struct {
@@ -139,14 +139,15 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
}
// Topics are in the format:
// c/<channel_id>/m/<subtopic>/.../ct/<content_type>
// m/<domain_id>/c/<channel_id>/<subtopic>/.../ct/<content_type>
channelParts := channelRegExp.FindStringSubmatch(*topic)
if len(channelParts) < 2 {
if len(channelParts) < 3 {
return errors.Wrap(errFailedPublish, errMalformedTopic)
}
chanID := channelParts[1]
subtopic := channelParts[2]
domainID := channelParts[1]
chanID := channelParts[2]
subtopic := channelParts[3]
subtopic, err := parseSubtopic(subtopic)
if err != nil {
@@ -160,6 +161,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
msg := messaging.Message{
Protocol: protocol,
Domain: domainID,
Channel: chanID,
Subtopic: subtopic,
Payload: *payload,
@@ -230,17 +232,19 @@ func (h *handler) authAccess(ctx context.Context, token, topic string, msgType c
}
channelParts := channelRegExp.FindStringSubmatch(topic)
if len(channelParts) < 1 {
if len(channelParts) < 3 {
return "", "", errMalformedTopic
}
chanID := channelParts[1]
domainID := channelParts[1]
chanID := channelParts[2]
ar := &grpcChannelsV1.AuthzReq{
Type: uint32(msgType),
ClientId: clientID,
ClientType: clientType,
ChannelId: chanID,
DomainId: domainID,
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
+2 -2
View File
@@ -32,9 +32,9 @@ func New(tracer trace.Tracer, svc ws.Service) ws.Service {
}
// Subscribe traces the "Subscribe" operation of the wrapped ws.Service.
func (tm *tracingMiddleware) Subscribe(ctx context.Context, clientKey, chanID, subtopic string, client *ws.Client) error {
func (tm *tracingMiddleware) Subscribe(ctx context.Context, clientKey, domainID, chanID, subtopic string, client *ws.Client) error {
ctx, span := tm.tracer.Start(ctx, subscribeOP)
defer span.End()
return tm.svc.Subscribe(ctx, clientKey, chanID, subtopic, client)
return tm.svc.Subscribe(ctx, clientKey, domainID, chanID, subtopic, client)
}