SMQ-3073 - Enable user messaging through WS and HTTP protocols (#3075)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2025-08-21 17:10:34 +03:00
committed by GitHub
parent 4927333691
commit c63c936b36
14 changed files with 914 additions and 288 deletions
+4 -3
View File
@@ -20,6 +20,7 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authn/authsvc"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
"github.com/absmach/supermq/pkg/grpcclient"
@@ -209,7 +210,7 @@ func main() {
return
}
svc := newService(clientsClient, channelsClient, nps, logger, tracer)
svc := newService(clientsClient, channelsClient, authn, nps, logger, tracer)
hs := httpserver.NewServer(ctx, cancel, svcName, targetServerConfig, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger)
@@ -236,8 +237,8 @@ func main() {
}
}
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service {
svc := ws.New(clientsClient, channels, nps)
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn authn.Authentication, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service {
svc := ws.New(clientsClient, channels, authn, nps)
svc = tracing.New(tracer, svc)
svc = httpapi.LoggingMiddleware(svc, logger)
counter, latency := prometheus.MakeMetrics("ws_adapter", "api")
+62 -6
View File
@@ -31,6 +31,7 @@ import (
smqauthn "github.com/absmach/supermq/pkg/authn"
authnMocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/connections"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
pubsub "github.com/absmach/supermq/pkg/messaging/mocks"
"github.com/absmach/supermq/pkg/policies"
@@ -47,6 +48,7 @@ var (
clientID = testsutil.GenerateUUID(&testing.T{})
chanID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
userID = testsutil.GenerateUUID(&testing.T{})
)
func newService(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (session.Handler, *pubsub.PubSub, error) {
@@ -91,6 +93,7 @@ type testRequest struct {
token string
body io.Reader
basicAuth bool
bearerToken bool
}
func (tr testRequest) make() (*http.Response, error) {
@@ -100,10 +103,14 @@ func (tr testRequest) make() (*http.Response, error) {
}
if tr.token != "" {
req.Header.Set("Authorization", apiutil.ClientPrefix+tr.token)
}
if tr.basicAuth && tr.token != "" {
req.SetBasicAuth("", apiutil.ClientPrefix+tr.token)
switch {
case tr.basicAuth:
req.SetBasicAuth("", apiutil.ClientPrefix+tr.token)
case tr.bearerToken:
req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token)
default:
req.Header.Set("Authorization", apiutil.ClientPrefix+tr.token)
}
}
if tr.contentType != "" {
req.Header.Set("Content-Type", tr.contentType)
@@ -121,6 +128,8 @@ func TestPublish(t *testing.T) {
ctJSON := "application/json"
clientKey := "client_key"
invalidKey := invalidValue
validToken := "token"
invalidToken := "invalid_token"
msg := `[{"n":"current","t":-1,"v":1.6}]`
msgJSON := `{"field1":"val1","field2":"val2"}`
msgCBOR := `81A3616E6763757272656E746174206176FB3FF999999999999A`
@@ -137,13 +146,17 @@ func TestPublish(t *testing.T) {
desc string
domainID string
chanID string
clientID string
clientType string
msg string
contentType string
key string
status int
basicAuth bool
bearerToken bool
authnErr error
authnRes *grpcClientsV1.AuthnRes
authnRes1 smqauthn.Session
authzRes *grpcChannelsV1.AuthzRes
authzErr error
err error
@@ -152,6 +165,7 @@ func TestPublish(t *testing.T) {
desc: "publish message successfully",
domainID: domainID,
chanID: chanID,
clientID: clientID,
msg: msg,
contentType: ctSenmlJSON,
key: clientKey,
@@ -163,6 +177,7 @@ func TestPublish(t *testing.T) {
desc: "publish message with application/senml+cbor content-type",
domainID: domainID,
chanID: chanID,
clientID: clientID,
msg: msgCBOR,
contentType: ctSenmlCBOR,
key: clientKey,
@@ -174,6 +189,7 @@ func TestPublish(t *testing.T) {
desc: "publish message with application/json content-type",
domainID: domainID,
chanID: chanID,
clientID: clientID,
msg: msgJSON,
contentType: ctJSON,
key: clientKey,
@@ -185,6 +201,7 @@ func TestPublish(t *testing.T) {
desc: "publish message with empty key",
domainID: domainID,
chanID: chanID,
clientID: clientID,
msg: msg,
contentType: ctSenmlJSON,
key: "",
@@ -194,6 +211,7 @@ func TestPublish(t *testing.T) {
desc: "publish message with basic auth",
domainID: domainID,
chanID: chanID,
clientID: clientID,
msg: msg,
contentType: ctSenmlJSON,
key: clientKey,
@@ -206,6 +224,7 @@ func TestPublish(t *testing.T) {
desc: "publish message with invalid key",
domainID: domainID,
chanID: chanID,
clientID: clientID,
msg: msg,
contentType: ctSenmlJSON,
key: invalidKey,
@@ -216,6 +235,7 @@ func TestPublish(t *testing.T) {
desc: "publish message with invalid basic auth",
domainID: domainID,
chanID: chanID,
clientID: clientID,
msg: msg,
contentType: ctSenmlJSON,
key: invalidKey,
@@ -223,10 +243,37 @@ func TestPublish(t *testing.T) {
status: http.StatusUnauthorized,
authnRes: &grpcClientsV1.AuthnRes{Authenticated: false},
},
{
desc: "publish message with valid bearer token",
domainID: domainID,
chanID: chanID,
clientID: userID,
msg: msg,
contentType: ctSenmlJSON,
key: validToken,
bearerToken: true,
status: http.StatusAccepted,
authnRes1: smqauthn.Session{UserID: userID},
authzRes: &grpcChannelsV1.AuthzRes{Authorized: true},
},
{
desc: "publish message with invalid bearer token",
domainID: domainID,
chanID: chanID,
clientID: userID,
msg: msg,
contentType: ctSenmlJSON,
key: invalidToken,
bearerToken: true,
status: http.StatusUnauthorized,
authnRes1: smqauthn.Session{},
authnErr: svcerr.ErrAuthentication,
},
{
desc: "publish message without content type",
domainID: domainID,
chanID: chanID,
clientID: clientID,
msg: msg,
contentType: "",
key: clientKey,
@@ -238,6 +285,7 @@ func TestPublish(t *testing.T) {
desc: "publish message to empty channel",
domainID: domainID,
chanID: "",
clientID: clientID,
msg: msg,
contentType: ctSenmlJSON,
key: clientKey,
@@ -249,6 +297,7 @@ func TestPublish(t *testing.T) {
desc: "publish message with invalid domain ID",
domainID: invalidValue,
chanID: chanID,
clientID: clientID,
msg: msg,
contentType: ctSenmlJSON,
key: clientKey,
@@ -261,12 +310,17 @@ func TestPublish(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.key)}).Return(tc.authnRes, tc.authnErr)
authCall := authn.On("Authenticate", mock.Anything, tc.key).Return(tc.authnRes1, tc.authnErr)
domainsCall := domains.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil)
tc.clientType = policies.ClientType
if tc.bearerToken {
tc.clientType = policies.UserType
}
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
DomainId: tc.domainID,
ChannelId: tc.chanID,
ClientId: clientID,
ClientType: policies.ClientType,
ClientId: tc.clientID,
ClientType: tc.clientType,
Type: uint32(connections.Publish),
}).Return(tc.authzRes, tc.authzErr)
svcCall := pub.On("Publish", mock.Anything, messaging.EncodeTopicSuffix(tc.domainID, tc.chanID, ""), mock.Anything).Return(nil)
@@ -278,12 +332,14 @@ func TestPublish(t *testing.T) {
token: tc.key,
body: strings.NewReader(tc.msg),
basicAuth: tc.basicAuth,
bearerToken: tc.bearerToken,
}
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
clientsCall.Unset()
authCall.Unset()
channelsCall.Unset()
domainsCall.Unset()
})
+5 -4
View File
@@ -21,9 +21,10 @@ import (
)
const (
ctSenmlJSON = "application/senml+json"
ctSenmlCBOR = "application/senml+cbor"
contentType = "application/json"
ctSenmlJSON = "application/senml+json"
ctSenmlCBOR = "application/senml+cbor"
contentType = "application/json"
authzHeaderKey = "Authorization"
)
// MakeHandler returns a HTTP handler for API endpoints.
@@ -64,7 +65,7 @@ func decodeRequest(_ context.Context, r *http.Request) (interface{}, error) {
case ok:
req.token = pass
case !ok:
req.token = apiutil.ExtractClientSecret(r)
req.token = r.Header.Get(authzHeaderKey)
}
payload, err := io.ReadAll(r.Body)
+9 -20
View File
@@ -36,7 +36,6 @@ const (
// Log message formats.
const (
logInfoConnected = "connected with client_key %s"
logInfoPublished = "published with client_type %s client_id %s to the topic %s"
logInfoFailedAuthNToken = "failed to authenticate token for topic %s with error %s"
logInfoFailedAuthNClient = "failed to authenticate client key %s for topic %s with error %s"
@@ -81,17 +80,10 @@ func (h *handler) AuthConnect(ctx context.Context) error {
return errClientNotInitialized
}
var tok string
switch {
case string(s.Password) == "":
if string(s.Password) == "" {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerKey))
case strings.HasPrefix(string(s.Password), apiutil.ClientPrefix):
tok = strings.TrimPrefix(string(s.Password), apiutil.ClientPrefix)
default:
tok = string(s.Password)
}
h.logger.Info(fmt.Sprintf(logInfoConnected, tok))
return nil
}
@@ -149,18 +141,19 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
clientType = policies.UserType
clientID = authnSession.DomainUserID
clientID = authnSession.UserID
default:
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
msg := messaging.Message{
Protocol: protocol,
Domain: domainID,
Channel: channelID,
Subtopic: subtopic,
Payload: *payload,
Created: time.Now().UnixNano(),
Protocol: protocol,
Domain: domainID,
Channel: channelID,
Subtopic: subtopic,
Publisher: clientID,
Payload: *payload,
Created: time.Now().UnixNano(),
}
ar := &grpcChannelsV1.AuthzReq{
@@ -178,10 +171,6 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthorization)
}
if clientType == policies.ClientType {
msg.Publisher = clientID
}
if err := h.publisher.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil {
return errors.Wrap(errFailedPublishToMsgBroker, err)
}
+16
View File
@@ -54,6 +54,7 @@ var (
Password: []byte(clientKey),
}
validToken = "token"
invalidToken = "invalid_token"
validID = testsutil.GenerateUUID(&testing.T{})
errClientNotInitialized = errors.New("client is not initialized")
errMissingTopicPub = errors.New("failed to publish due to missing topic")
@@ -148,6 +149,9 @@ func TestPublish(t *testing.T) {
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
invalidTokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + invalidToken),
}
cases := []struct {
desc string
topic *string
@@ -190,6 +194,18 @@ func TestPublish(t *testing.T) {
authZErr: nil,
err: nil,
},
{
desc: "publish with invalid token",
topic: &topic,
payload: &payload,
password: invalidToken,
session: &invalidTokenSession,
channelID: chanID,
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with key and subtopic successfully",
topic: &subtopic,
+32 -17
View File
@@ -9,7 +9,8 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
"github.com/absmach/supermq/pkg/authn"
apiutil "github.com/absmach/supermq/api/http/util"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -42,14 +43,16 @@ var _ Service = (*adapterService)(nil)
type adapterService struct {
clients grpcClientsV1.ClientsServiceClient
channels grpcChannelsV1.ChannelsServiceClient
authn smqauthn.Authentication
pubsub messaging.PubSub
}
// New instantiates the WS adapter implementation.
func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, pubsub messaging.PubSub) Service {
func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn smqauthn.Authentication, pubsub messaging.PubSub) Service {
return &adapterService{
clients: clients,
channels: channels,
authn: authn,
pubsub: pubsub,
}
}
@@ -89,23 +92,35 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID,
return nil
}
// authorize checks if the clientKey is authorized to access the channel
// and returns the clientID if it is.
func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, chanID string, msgType connections.ConnType) (string, error) {
if strings.HasPrefix(clientKey, "Client") {
clientKey = extractClientSecret(clientKey)
}
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, domainID, clientKey)})
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
if !authnRes.GetAuthenticated() {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
// authorize checks if the authKey is authorized to access the channel
// and returns the clientID or userID if it is.
func (svc *adapterService) authorize(ctx context.Context, authKey, domainID, chanID string, msgType connections.ConnType) (string, error) {
var clientID, clientType string
switch {
case strings.HasPrefix(authKey, apiutil.BearerPrefix):
token := strings.TrimPrefix(authKey, apiutil.BearerPrefix)
authnSession, err := svc.authn.Authenticate(ctx, token)
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
clientType = policies.UserType
clientID = authnSession.UserID
default:
secret := strings.TrimPrefix(authKey, apiutil.ClientPrefix)
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
if !authnRes.Authenticated {
return "", svcerr.ErrAuthentication
}
clientType = policies.ClientType
clientID = authnRes.GetId()
}
authzReq := &grpcChannelsV1.AuthzReq{
ClientType: policies.ClientType,
ClientId: authnRes.GetId(),
ClientType: clientType,
ClientId: clientID,
Type: uint32(msgType),
ChannelId: chanID,
DomainId: domainID,
@@ -118,5 +133,5 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, c
return "", errors.Wrap(svcerr.ErrAuthorization, err)
}
return authnRes.GetId(), nil
return clientID, nil
}
+153 -106
View File
@@ -12,10 +12,12 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
apiutil "github.com/absmach/supermq/api/http/util"
chmocks "github.com/absmach/supermq/channels/mocks"
climocks "github.com/absmach/supermq/clients/mocks"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/authn"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -28,18 +30,21 @@ import (
)
const (
chanID = "1"
invalidID = "invalidID"
invalidKey = "invalidKey"
id = "1"
clientKey = "client_key"
subTopic = "subtopic"
protocol = "ws"
invalidID = "invalidID"
invalidKey = "invalidKey"
id = "1"
clientKey = "client_key"
subTopic = "subtopic"
protocol = "ws"
token = "Bearer token"
invalidToken = "Bearer invalid_token"
)
var (
domainID = testsutil.GenerateUUID(&testing.T{})
clientID = testsutil.GenerateUUID(&testing.T{})
userID = testsutil.GenerateUUID(&testing.T{})
chanID = testsutil.GenerateUUID(&testing.T{})
msg = messaging.Message{
Channel: chanID,
Domain: domainID,
@@ -51,136 +56,172 @@ var (
sessionID = "sessionID"
)
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient) {
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient, *authnmocks.Authentication) {
pubsub := new(mocks.PubSub)
clients := new(climocks.ClientsServiceClient)
channels := new(chmocks.ChannelsServiceClient)
authn := new(authnmocks.Authentication)
return ws.New(clients, channels, pubsub), pubsub, clients, channels
return ws.New(clients, channels, authn, pubsub), pubsub, clients, channels, authn
}
func TestSubscribe(t *testing.T) {
svc, pubsub, clients, channels := newService()
svc, pubsub, clients, channels, auth := newService()
c := ws.NewClient(slog.Default(), nil, sessionID)
cases := []struct {
desc string
clientKey string
chanID string
domainID string
subtopic string
authNRes *grpcClientsV1.AuthnRes
authNErr error
authZRes *grpcChannelsV1.AuthzRes
authZErr error
subErr error
err error
desc string
authKey string
chanID string
domainID string
subtopic string
clientType string
clientID string
authNRes *grpcClientsV1.AuthnRes
authNErr error
authNRes1 smqauthn.Session
authZRes *grpcChannelsV1.AuthzRes
authZErr error
subErr error
err error
}{
{
desc: "subscribe to channel with valid clientKey, chanID, subtopic",
clientKey: clientKey,
desc: "subscribe to channel with valid clientKey, chanID, subtopic",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe to channel with valid token, chanID, subtopic",
authKey: token,
chanID: chanID,
domainID: domainID,
clientID: userID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNRes1: smqauthn.Session{UserID: userID},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe again to channel with valid clientKey, chanID, subtopic",
clientKey: clientKey,
desc: "subscribe to channel with invalid token",
authKey: invalidToken,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe to channel with subscribe set to fail",
clientKey: clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
subErr: ws.ErrFailedSubscription,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: ws.ErrFailedSubscription,
},
{
desc: "subscribe to channel with invalid clientKey",
clientKey: invalidKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with empty channel",
clientKey: clientKey,
chanID: "",
domainID: domainID,
subtopic: subTopic,
err: svcerr.ErrAuthentication,
desc: "subscribe again to channel with valid clientKey, chanID, subtopic",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe to channel with empty clientKey",
clientKey: "",
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
err: svcerr.ErrAuthentication,
desc: "subscribe to channel with subscribe set to fail",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
subErr: ws.ErrFailedSubscription,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: ws.ErrFailedSubscription,
},
{
desc: "subscribe to channel with empty clientKey and empty channel",
clientKey: "",
chanID: "",
domainID: domainID,
subtopic: subTopic,
err: svcerr.ErrAuthentication,
desc: "subscribe to channel with invalid clientKey",
authKey: invalidKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
authNErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with invalid channel",
clientKey: clientKey,
chanID: invalidID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
desc: "subscribe to channel with empty channel",
authKey: clientKey,
chanID: "",
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe to channel with failed authentication",
clientKey: clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
err: svcerr.ErrAuthorization,
desc: "subscribe to channel with empty clientKey",
authKey: "",
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe to channel with failed authorization",
clientKey: clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
err: svcerr.ErrAuthorization,
desc: "subscribe to channel with empty clientKey and empty channel",
authKey: "",
chanID: "",
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic",
clientKey: "Client " + clientKey,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
desc: "subscribe to channel with invalid channel",
authKey: clientKey,
chanID: invalidID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with failed authentication",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with failed authorization",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic",
authKey: "Client " + clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
}
@@ -188,26 +229,32 @@ func TestSubscribe(t *testing.T) {
subConfig := messaging.SubscriberConfig{
ID: sessionID,
Topic: "m." + tc.domainID + ".c." + tc.chanID + "." + subTopic,
ClientID: clientID,
ClientID: tc.clientID,
Handler: c,
}
authReq := &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, tc.domainID, tc.clientKey)}
if strings.HasPrefix(tc.clientKey, "Client") {
authReq.Token = authn.AuthPack(authn.DomainAuth, tc.domainID, strings.TrimPrefix(tc.clientKey, "Client "))
authReq := &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.authKey)}
tc.clientType = policies.ClientType
if strings.HasPrefix(tc.authKey, "Client") {
authReq.Token = smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, strings.TrimPrefix(tc.authKey, "Client "))
}
if strings.HasPrefix(tc.authKey, apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", mock.Anything, authReq).Return(tc.authNRes, tc.authNErr)
authCall := auth.On("Authenticate", mock.Anything, strings.TrimPrefix(tc.authKey, apiutil.BearerPrefix)).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: policies.ClientType,
ClientId: tc.authNRes.GetId(),
ClientType: tc.clientType,
ClientId: tc.clientID,
Type: uint32(connections.Subscribe),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
repoCall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
err := svc.Subscribe(context.Background(), sessionID, tc.clientKey, tc.domainID, tc.chanID, tc.subtopic, c)
err := svc.Subscribe(context.Background(), sessionID, tc.authKey, tc.domainID, tc.chanID, tc.subtopic, c)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
clientsCall.Unset()
authCall.Unset()
channelsCall.Unset()
}
}
+101 -89
View File
@@ -39,6 +39,7 @@ const (
clientKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529"
protocol = "ws"
instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
validToken = "Bearer token"
)
var (
@@ -49,7 +50,8 @@ var (
func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) (ws.Service, *mocks.PubSub) {
pubsub := new(mocks.PubSub)
return ws.New(clients, channels, pubsub), pubsub
authn := new(authnMocks.Authentication)
return ws.New(clients, channels, authn, pubsub), pubsub
}
func newHTTPServer(svc ws.Service, resolver messaging.TopicResolver) *httptest.Server {
@@ -77,7 +79,7 @@ func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*ht
return httptest.NewServer(http.HandlerFunc(mp.ServeHTTP)), nil
}
func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (string, error) {
func makeURL(tsURL, domainID, chanID, subtopic, authKey string, header bool) (string, error) {
u, _ := url.Parse(tsURL)
u.Scheme = protocol
@@ -85,7 +87,7 @@ func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (
if header {
return fmt.Sprintf("%s/m/%s/c/%s", u, domainID, chanID), fmt.Errorf("invalid channel id")
}
return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, clientKey), fmt.Errorf("invalid channel id")
return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, authKey), fmt.Errorf("invalid channel id")
}
subtopicPart := ""
@@ -96,16 +98,16 @@ func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (
return fmt.Sprintf("%s/m/%s/c/%s%s", u, domainID, chanID, subtopicPart), nil
}
return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, clientKey), nil
return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, authKey), nil
}
func handshake(tsURL, domainID, chanID, subtopic, clientKey string, addHeader bool) (*websocket.Conn, *http.Response, error) {
func handshake(tsURL, domainID, chanID, subtopic, authKey string, addHeader bool) (*websocket.Conn, *http.Response, error) {
header := http.Header{}
if addHeader {
header.Add("Authorization", clientKey)
header.Add("Authorization", authKey)
}
turl, _ := makeURL(tsURL, domainID, chanID, subtopic, clientKey, addHeader)
turl, _ := makeURL(tsURL, domainID, chanID, subtopic, authKey, addHeader)
conn, res, errRet := websocket.DefaultDialer.Dial(turl, header)
return conn, res, errRet
@@ -138,111 +140,121 @@ func TestHandshake(t *testing.T) {
channels.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil)
cases := []struct {
desc string
domainID string
chanID string
subtopic string
header bool
clientKey string
status int
err error
msg []byte
desc string
domainID string
chanID string
subtopic string
header bool
authKey string
status int
err error
msg []byte
}{
{
desc: "connect and send message",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
desc: "connect and send message",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message with clientKey as query parameter",
domainID: domainID,
chanID: id,
subtopic: "",
header: false,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
desc: "connect and send message with clientKey as query parameter",
domainID: domainID,
chanID: id,
subtopic: "",
header: false,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message that cannot be published",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: []byte{},
desc: "connect and send message with valid token",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
authKey: validToken,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message to subtopic",
domainID: domainID,
chanID: id,
subtopic: "subtopic",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
desc: "connect and send message that cannot be published",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: []byte{},
},
{
desc: "connect and send message to nested subtopic",
domainID: domainID,
chanID: id,
subtopic: "subtopic/nested",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
desc: "connect and send message to subtopic",
domainID: domainID,
chanID: id,
subtopic: "subtopic",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect and send message to all subtopics",
domainID: domainID,
chanID: id,
subtopic: ">",
header: true,
clientKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
desc: "connect and send message to nested subtopic",
domainID: domainID,
chanID: id,
subtopic: "subtopic/nested",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect to empty channel",
domainID: domainID,
chanID: "",
subtopic: "",
header: true,
clientKey: clientKey,
status: http.StatusUnauthorized,
msg: []byte{},
desc: "connect and send message to all subtopics",
domainID: domainID,
chanID: id,
subtopic: ">",
header: true,
authKey: clientKey,
status: http.StatusSwitchingProtocols,
msg: msg,
},
{
desc: "connect with empty clientKey",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
clientKey: "",
status: http.StatusUnauthorized,
msg: []byte{},
desc: "connect to empty channel",
domainID: domainID,
chanID: "",
subtopic: "",
header: true,
authKey: clientKey,
status: http.StatusUnauthorized,
msg: []byte{},
},
{
desc: "connect and send message to subtopic with invalid name",
domainID: domainID,
chanID: id,
subtopic: "sub/a*b/topic",
header: true,
clientKey: clientKey,
status: http.StatusUnauthorized,
msg: msg,
desc: "connect with empty clientKey",
domainID: domainID,
chanID: id,
subtopic: "",
header: true,
authKey: "",
status: http.StatusUnauthorized,
msg: []byte{},
},
{
desc: "connect and send message to subtopic with invalid name",
domainID: domainID,
chanID: id,
subtopic: "sub/a*b/topic",
header: true,
authKey: clientKey,
status: http.StatusUnauthorized,
msg: msg,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.clientKey, tc.header)
conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.authKey, tc.header)
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode))
if tc.status == http.StatusSwitchingProtocols {
+2 -2
View File
@@ -56,7 +56,7 @@ func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicReso
go client.Start(ctx)
if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.channelID, req.subtopic, client); err != nil {
if err := svc.Subscribe(ctx, sessionID, req.authKey, req.domainID, req.channelID, req.subtopic, client); err != nil {
conn.Close()
return
}
@@ -85,7 +85,7 @@ func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *sl
}
req := connReq{
clientKey: authKey,
authKey: authKey,
channelID: channelID,
domainID: domainID,
}
+2 -2
View File
@@ -25,7 +25,7 @@ func LoggingMiddleware(svc ws.Service, logger *slog.Logger) ws.Service {
// Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) (err error) {
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, authKey, domainID, chanID, subtopic string, c *ws.Client) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
@@ -44,7 +44,7 @@ func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey
lm.logger.Info("Subscribe completed successfully", args...)
}(time.Now())
return lm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, c)
return lm.svc.Subscribe(ctx, sessionID, authKey, domainID, chanID, subtopic, c)
}
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) (err error) {
+2 -2
View File
@@ -31,13 +31,13 @@ func MetricsMiddleware(svc ws.Service, counter metrics.Counter, latency metrics.
}
// Subscribe instruments Subscribe method with metrics.
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) error {
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, authKey, domainID, chanID, subtopic string, c *ws.Client) error {
defer func(begin time.Time) {
mm.counter.With("method", "subscribe").Add(1)
mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, c)
return mm.svc.Subscribe(ctx, sessionID, authKey, domainID, chanID, subtopic, c)
}
func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error {
+1 -1
View File
@@ -4,7 +4,7 @@
package api
type connReq struct {
clientKey string
authKey string
channelID string
domainID string
subtopic string
+28 -36
View File
@@ -16,7 +16,6 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/authn"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
@@ -85,25 +84,17 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
return errClientNotInitialized
}
var token string
switch {
case strings.HasPrefix(string(s.Password), "Client"):
token = strings.ReplaceAll(string(s.Password), "Client ", "")
default:
token = string(s.Password)
}
domainID, channelID, _, err := h.parser.ParsePublishTopic(ctx, *topic, true)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
}
clientID, clientType, err := h.authAccess(ctx, token, domainID, channelID, connections.Publish)
clientID, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Publish)
if err != nil {
return err
}
if s.Username == "" && clientType == policies.ClientType {
if s.Username == "" {
s.Username = clientID
}
@@ -126,7 +117,7 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
if err != nil {
return err
}
if _, _, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Subscribe); err != nil {
if _, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Subscribe); err != nil {
return err
}
}
@@ -194,19 +185,29 @@ func (h *handler) Disconnect(ctx context.Context) error {
return nil
}
func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType) (string, string, error) {
if strings.HasPrefix(token, "Client") {
token = extractClientSecret(token)
func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType) (string, error) {
var clientID, clientType string
switch {
case strings.HasPrefix(token, apiutil.BearerPrefix):
token := strings.TrimPrefix(token, apiutil.BearerPrefix)
authnSession, err := h.authn.Authenticate(ctx, token)
if err != nil {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
clientType = policies.UserType
clientID = authnSession.UserID
default:
secret := strings.TrimPrefix(token, apiutil.ClientPrefix)
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
if err != nil {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
if !authnRes.Authenticated {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
clientType = policies.ClientType
clientID = authnRes.GetId()
}
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, domainID, token)})
if err != nil {
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}
if !authnRes.GetAuthenticated() {
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
clientType := policies.ClientType
clientID := authnRes.GetId()
ar := &grpcChannelsV1.AuthzReq{
Type: uint32(msgType),
@@ -217,20 +218,11 @@ func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}
if !res.GetAuthorized() {
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
return clientID, clientType, nil
}
// extractClientSecret returns value of the client secret. If there is no client key - an empty value is returned.
func extractClientSecret(token string) string {
if !strings.HasPrefix(token, apiutil.ClientPrefix) {
return ""
}
return strings.TrimPrefix(token, apiutil.ClientPrefix)
return clientID, nil
}
+497
View File
@@ -0,0 +1,497 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package ws_test
import (
"context"
"fmt"
"net/http"
"strings"
"testing"
mgate "github.com/absmach/mgate/pkg/http"
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
apiutil "github.com/absmach/supermq/api/http/util"
chmocks "github.com/absmach/supermq/channels/mocks"
clmocks "github.com/absmach/supermq/clients/mocks"
dmocks "github.com/absmach/supermq/domains/mocks"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/messaging/mocks"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/ws"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
var (
invalidValue = "invalid"
topicMsg = "/m/%s/c/%s"
subtopicMsg = "/m/%s/c/%s/subtopic"
topic = fmt.Sprintf(topicMsg, domainID, chanID)
subtopic = fmt.Sprintf(subtopicMsg, domainID, chanID)
invalidTopic = invalidValue
topics = []string{topic}
payload = []byte("[{'n':'test-name', 'v': 1.2}]")
sessionClient = session.Session{
ID: clientID,
Password: []byte(clientKey),
}
invalidChannelIDTopic = "m/**/c"
validToken = "token"
errClientNotInitialized = errors.New("client is not initialized")
errMissingTopicPub = errors.New("failed to publish due to missing topic")
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
)
var (
clients = new(clmocks.ClientsServiceClient)
channels = new(chmocks.ChannelsServiceClient)
authn = new(authnmocks.Authentication)
publisher = new(mocks.PubSub)
domains = new(dmocks.DomainsServiceClient)
)
func newHandler(t *testing.T) session.Handler {
logger := smqlog.NewMock()
authn = new(authnmocks.Authentication)
clients = new(clmocks.ClientsServiceClient)
channels = new(chmocks.ChannelsServiceClient)
publisher = new(mocks.PubSub)
parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains)
assert.Nil(t, err, fmt.Sprintf("unexpected error while creating topic parser: %v", err))
return ws.NewHandler(publisher, logger, authn, clients, channels, parser)
}
func TestAuthPublish(t *testing.T) {
handler := newHandler(t)
clientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
invalidClientKeySession := session.Session{
Password: []byte("Client " + invalidKey),
}
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
invalidTokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + invalidToken),
}
tests := []struct {
desc string
session *session.Session
topic *string
payload *[]byte
authKey string
status int
clientType string
chanID string
domainID string
clientID string
authNRes *grpcClientsV1.AuthnRes
authNRes1 smqauthn.Session
authNErr error
authZRes *grpcChannelsV1.AuthzRes
authZErr error
err error
}{
{
desc: "publish with client key successfully",
session: &clientKeySession,
topic: &topic,
authKey: clientKey,
payload: &payload,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid client key",
session: &invalidClientKeySession,
topic: &topic,
authKey: invalidKey,
payload: &payload,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with nil session",
session: nil,
topic: &topic,
authKey: clientKey,
status: http.StatusInternalServerError,
err: errClientNotInitialized,
},
{
desc: "publish with empty topic",
session: &clientKeySession,
topic: nil,
authKey: clientKey,
status: http.StatusBadRequest,
err: errMissingTopicPub,
},
{
desc: "publish with unauthorized client key",
session: &clientKeySession,
topic: &topic,
authKey: clientKey,
payload: &payload,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with token successfully",
session: &tokenSession,
topic: &topic,
authKey: token,
payload: &payload,
status: http.StatusOK,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid token",
session: &invalidTokenSession,
topic: &topic,
authKey: invalidToken,
payload: &payload,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with unauthorized client key",
session: &tokenSession,
topic: &topic,
authKey: token,
payload: &payload,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
tc.clientType = policies.ClientType
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.authKey)}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
ClientId: tc.clientID,
Type: uint32(connections.Publish),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
err := handler.AuthPublish(ctx, tc.topic, tc.payload)
hpe, ok := err.(mgate.HTTPProxyError)
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
assert.True(t, errors.Contains(err, tc.err))
authCall.Unset()
clientsCall.Unset()
channelsCall.Unset()
})
}
}
func TestAuthSubscribe(t *testing.T) {
handler := newHandler(t)
clientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
invalidClientKeySession := session.Session{
Password: []byte("Client " + invalidKey),
}
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
invalidTokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + invalidToken),
}
tests := []struct {
desc string
session *session.Session
topics *[]string
authKey string
status int
clientType string
chanID string
domainID string
clientID string
authNRes *grpcClientsV1.AuthnRes
authNRes1 smqauthn.Session
authNErr error
authZRes *grpcChannelsV1.AuthzRes
authZErr error
err error
}{
{
desc: "subscribe with client key successfully",
session: &clientKeySession,
topics: &topics,
authKey: clientKey,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe with invalid client key",
session: &invalidClientKeySession,
topics: &topics,
authKey: invalidKey,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with empty topics",
session: &clientKeySession,
topics: nil,
authKey: clientKey,
status: http.StatusBadRequest,
err: errMissingTopicSub,
},
{
desc: "subscribe with nil session",
session: nil,
topics: &topics,
authKey: clientKey,
status: http.StatusInternalServerError,
err: errClientNotInitialized,
},
{
desc: "subscribe with unauthorized client key",
session: &clientKeySession,
topics: &topics,
authKey: clientKey,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with token successfully",
session: &tokenSession,
topics: &topics,
authKey: token,
status: http.StatusOK,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe with invalid token",
session: &invalidTokenSession,
topics: &topics,
authKey: invalidToken,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with unauthorized client key",
session: &tokenSession,
topics: &topics,
authKey: token,
clientType: policies.UserType,
chanID: chanID,
domainID: domainID,
clientID: userID,
authNRes1: smqauthn.Session{UserID: userID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
}
for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
tc.clientType = policies.ClientType
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.authKey)}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
ClientId: tc.clientID,
Type: uint32(connections.Subscribe),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
err := handler.AuthSubscribe(ctx, tc.topics)
hpe, ok := err.(mgate.HTTPProxyError)
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
assert.True(t, errors.Contains(err, tc.err))
authCall.Unset()
clientsCall.Unset()
channelsCall.Unset()
})
}
}
func TestPublish(t *testing.T) {
handler := newHandler(t)
malformedSubtopics := topic + "/" + subtopic + "%"
wrongCharSubtopics := topic + "/" + subtopic + ">"
validSubtopic := topic + "/" + subtopic
cases := []struct {
desc string
session *session.Session
topic string
payload []byte
err error
}{
{
desc: "publish without active session",
session: nil,
topic: topic,
payload: payload,
err: errClientNotInitialized,
},
{
desc: "publish with invalid topic",
session: &sessionClient,
topic: invalidTopic,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with invalid channel ID",
session: &sessionClient,
topic: invalidChannelIDTopic,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with malformed subtopic",
session: &sessionClient,
topic: malformedSubtopics,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with subtopic containing wrong character",
session: &sessionClient,
topic: wrongCharSubtopics,
payload: payload,
err: messaging.ErrMalformedTopic,
},
{
desc: "publish with subtopic",
session: &sessionClient,
topic: validSubtopic,
payload: payload,
},
{
desc: "publish without subtopic",
session: &sessionClient,
topic: topic,
payload: payload,
},
}
for _, tc := range cases {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
err := handler.Publish(ctx, &tc.topic, &tc.payload)
assert.True(t, errors.Contains(err, tc.err))
repoCall.Unset()
}
}