mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
SMQ-3073 - Enable user messaging through WS and HTTP protocols (#3075)
Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
+4
-3
@@ -20,6 +20,7 @@ import (
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
|
||||
"github.com/absmach/supermq/pkg/grpcclient"
|
||||
@@ -209,7 +210,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
svc := newService(clientsClient, channelsClient, nps, logger, tracer)
|
||||
svc := newService(clientsClient, channelsClient, authn, nps, logger, tracer)
|
||||
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, targetServerConfig, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger)
|
||||
|
||||
@@ -236,8 +237,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service {
|
||||
svc := ws.New(clientsClient, channels, nps)
|
||||
func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn authn.Authentication, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) ws.Service {
|
||||
svc := ws.New(clientsClient, channels, authn, nps)
|
||||
svc = tracing.New(tracer, svc)
|
||||
svc = httpapi.LoggingMiddleware(svc, logger)
|
||||
counter, latency := prometheus.MakeMetrics("ws_adapter", "api")
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authnMocks "github.com/absmach/supermq/pkg/authn/mocks"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
pubsub "github.com/absmach/supermq/pkg/messaging/mocks"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
@@ -47,6 +48,7 @@ var (
|
||||
clientID = testsutil.GenerateUUID(&testing.T{})
|
||||
chanID = testsutil.GenerateUUID(&testing.T{})
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
userID = testsutil.GenerateUUID(&testing.T{})
|
||||
)
|
||||
|
||||
func newService(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (session.Handler, *pubsub.PubSub, error) {
|
||||
@@ -91,6 +93,7 @@ type testRequest struct {
|
||||
token string
|
||||
body io.Reader
|
||||
basicAuth bool
|
||||
bearerToken bool
|
||||
}
|
||||
|
||||
func (tr testRequest) make() (*http.Response, error) {
|
||||
@@ -100,10 +103,14 @@ func (tr testRequest) make() (*http.Response, error) {
|
||||
}
|
||||
|
||||
if tr.token != "" {
|
||||
req.Header.Set("Authorization", apiutil.ClientPrefix+tr.token)
|
||||
}
|
||||
if tr.basicAuth && tr.token != "" {
|
||||
req.SetBasicAuth("", apiutil.ClientPrefix+tr.token)
|
||||
switch {
|
||||
case tr.basicAuth:
|
||||
req.SetBasicAuth("", apiutil.ClientPrefix+tr.token)
|
||||
case tr.bearerToken:
|
||||
req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token)
|
||||
default:
|
||||
req.Header.Set("Authorization", apiutil.ClientPrefix+tr.token)
|
||||
}
|
||||
}
|
||||
if tr.contentType != "" {
|
||||
req.Header.Set("Content-Type", tr.contentType)
|
||||
@@ -121,6 +128,8 @@ func TestPublish(t *testing.T) {
|
||||
ctJSON := "application/json"
|
||||
clientKey := "client_key"
|
||||
invalidKey := invalidValue
|
||||
validToken := "token"
|
||||
invalidToken := "invalid_token"
|
||||
msg := `[{"n":"current","t":-1,"v":1.6}]`
|
||||
msgJSON := `{"field1":"val1","field2":"val2"}`
|
||||
msgCBOR := `81A3616E6763757272656E746174206176FB3FF999999999999A`
|
||||
@@ -137,13 +146,17 @@ func TestPublish(t *testing.T) {
|
||||
desc string
|
||||
domainID string
|
||||
chanID string
|
||||
clientID string
|
||||
clientType string
|
||||
msg string
|
||||
contentType string
|
||||
key string
|
||||
status int
|
||||
basicAuth bool
|
||||
bearerToken bool
|
||||
authnErr error
|
||||
authnRes *grpcClientsV1.AuthnRes
|
||||
authnRes1 smqauthn.Session
|
||||
authzRes *grpcChannelsV1.AuthzRes
|
||||
authzErr error
|
||||
err error
|
||||
@@ -152,6 +165,7 @@ func TestPublish(t *testing.T) {
|
||||
desc: "publish message successfully",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: clientID,
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: clientKey,
|
||||
@@ -163,6 +177,7 @@ func TestPublish(t *testing.T) {
|
||||
desc: "publish message with application/senml+cbor content-type",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: clientID,
|
||||
msg: msgCBOR,
|
||||
contentType: ctSenmlCBOR,
|
||||
key: clientKey,
|
||||
@@ -174,6 +189,7 @@ func TestPublish(t *testing.T) {
|
||||
desc: "publish message with application/json content-type",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: clientID,
|
||||
msg: msgJSON,
|
||||
contentType: ctJSON,
|
||||
key: clientKey,
|
||||
@@ -185,6 +201,7 @@ func TestPublish(t *testing.T) {
|
||||
desc: "publish message with empty key",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: clientID,
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: "",
|
||||
@@ -194,6 +211,7 @@ func TestPublish(t *testing.T) {
|
||||
desc: "publish message with basic auth",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: clientID,
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: clientKey,
|
||||
@@ -206,6 +224,7 @@ func TestPublish(t *testing.T) {
|
||||
desc: "publish message with invalid key",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: clientID,
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: invalidKey,
|
||||
@@ -216,6 +235,7 @@ func TestPublish(t *testing.T) {
|
||||
desc: "publish message with invalid basic auth",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: clientID,
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: invalidKey,
|
||||
@@ -223,10 +243,37 @@ func TestPublish(t *testing.T) {
|
||||
status: http.StatusUnauthorized,
|
||||
authnRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
},
|
||||
{
|
||||
desc: "publish message with valid bearer token",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: userID,
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: validToken,
|
||||
bearerToken: true,
|
||||
status: http.StatusAccepted,
|
||||
authnRes1: smqauthn.Session{UserID: userID},
|
||||
authzRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
},
|
||||
{
|
||||
desc: "publish message with invalid bearer token",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: userID,
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: invalidToken,
|
||||
bearerToken: true,
|
||||
status: http.StatusUnauthorized,
|
||||
authnRes1: smqauthn.Session{},
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish message without content type",
|
||||
domainID: domainID,
|
||||
chanID: chanID,
|
||||
clientID: clientID,
|
||||
msg: msg,
|
||||
contentType: "",
|
||||
key: clientKey,
|
||||
@@ -238,6 +285,7 @@ func TestPublish(t *testing.T) {
|
||||
desc: "publish message to empty channel",
|
||||
domainID: domainID,
|
||||
chanID: "",
|
||||
clientID: clientID,
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: clientKey,
|
||||
@@ -249,6 +297,7 @@ func TestPublish(t *testing.T) {
|
||||
desc: "publish message with invalid domain ID",
|
||||
domainID: invalidValue,
|
||||
chanID: chanID,
|
||||
clientID: clientID,
|
||||
msg: msg,
|
||||
contentType: ctSenmlJSON,
|
||||
key: clientKey,
|
||||
@@ -261,12 +310,17 @@ func TestPublish(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.key)}).Return(tc.authnRes, tc.authnErr)
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.key).Return(tc.authnRes1, tc.authnErr)
|
||||
domainsCall := domains.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil)
|
||||
tc.clientType = policies.ClientType
|
||||
if tc.bearerToken {
|
||||
tc.clientType = policies.UserType
|
||||
}
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
DomainId: tc.domainID,
|
||||
ChannelId: tc.chanID,
|
||||
ClientId: clientID,
|
||||
ClientType: policies.ClientType,
|
||||
ClientId: tc.clientID,
|
||||
ClientType: tc.clientType,
|
||||
Type: uint32(connections.Publish),
|
||||
}).Return(tc.authzRes, tc.authzErr)
|
||||
svcCall := pub.On("Publish", mock.Anything, messaging.EncodeTopicSuffix(tc.domainID, tc.chanID, ""), mock.Anything).Return(nil)
|
||||
@@ -278,12 +332,14 @@ func TestPublish(t *testing.T) {
|
||||
token: tc.key,
|
||||
body: strings.NewReader(tc.msg),
|
||||
basicAuth: tc.basicAuth,
|
||||
bearerToken: tc.bearerToken,
|
||||
}
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
clientsCall.Unset()
|
||||
authCall.Unset()
|
||||
channelsCall.Unset()
|
||||
domainsCall.Unset()
|
||||
})
|
||||
|
||||
@@ -21,9 +21,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
ctSenmlJSON = "application/senml+json"
|
||||
ctSenmlCBOR = "application/senml+cbor"
|
||||
contentType = "application/json"
|
||||
ctSenmlJSON = "application/senml+json"
|
||||
ctSenmlCBOR = "application/senml+cbor"
|
||||
contentType = "application/json"
|
||||
authzHeaderKey = "Authorization"
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for API endpoints.
|
||||
@@ -64,7 +65,7 @@ func decodeRequest(_ context.Context, r *http.Request) (interface{}, error) {
|
||||
case ok:
|
||||
req.token = pass
|
||||
case !ok:
|
||||
req.token = apiutil.ExtractClientSecret(r)
|
||||
req.token = r.Header.Get(authzHeaderKey)
|
||||
}
|
||||
|
||||
payload, err := io.ReadAll(r.Body)
|
||||
|
||||
+9
-20
@@ -36,7 +36,6 @@ const (
|
||||
|
||||
// Log message formats.
|
||||
const (
|
||||
logInfoConnected = "connected with client_key %s"
|
||||
logInfoPublished = "published with client_type %s client_id %s to the topic %s"
|
||||
logInfoFailedAuthNToken = "failed to authenticate token for topic %s with error %s"
|
||||
logInfoFailedAuthNClient = "failed to authenticate client key %s for topic %s with error %s"
|
||||
@@ -81,17 +80,10 @@ func (h *handler) AuthConnect(ctx context.Context) error {
|
||||
return errClientNotInitialized
|
||||
}
|
||||
|
||||
var tok string
|
||||
switch {
|
||||
case string(s.Password) == "":
|
||||
if string(s.Password) == "" {
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerKey))
|
||||
case strings.HasPrefix(string(s.Password), apiutil.ClientPrefix):
|
||||
tok = strings.TrimPrefix(string(s.Password), apiutil.ClientPrefix)
|
||||
default:
|
||||
tok = string(s.Password)
|
||||
}
|
||||
|
||||
h.logger.Info(fmt.Sprintf(logInfoConnected, tok))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -149,18 +141,19 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
clientType = policies.UserType
|
||||
clientID = authnSession.DomainUserID
|
||||
clientID = authnSession.UserID
|
||||
default:
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
|
||||
msg := messaging.Message{
|
||||
Protocol: protocol,
|
||||
Domain: domainID,
|
||||
Channel: channelID,
|
||||
Subtopic: subtopic,
|
||||
Payload: *payload,
|
||||
Created: time.Now().UnixNano(),
|
||||
Protocol: protocol,
|
||||
Domain: domainID,
|
||||
Channel: channelID,
|
||||
Subtopic: subtopic,
|
||||
Publisher: clientID,
|
||||
Payload: *payload,
|
||||
Created: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
ar := &grpcChannelsV1.AuthzReq{
|
||||
@@ -178,10 +171,6 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthorization)
|
||||
}
|
||||
|
||||
if clientType == policies.ClientType {
|
||||
msg.Publisher = clientID
|
||||
}
|
||||
|
||||
if err := h.publisher.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil {
|
||||
return errors.Wrap(errFailedPublishToMsgBroker, err)
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ var (
|
||||
Password: []byte(clientKey),
|
||||
}
|
||||
validToken = "token"
|
||||
invalidToken = "invalid_token"
|
||||
validID = testsutil.GenerateUUID(&testing.T{})
|
||||
errClientNotInitialized = errors.New("client is not initialized")
|
||||
errMissingTopicPub = errors.New("failed to publish due to missing topic")
|
||||
@@ -148,6 +149,9 @@ func TestPublish(t *testing.T) {
|
||||
tokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + validToken),
|
||||
}
|
||||
invalidTokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + invalidToken),
|
||||
}
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic *string
|
||||
@@ -190,6 +194,18 @@ func TestPublish(t *testing.T) {
|
||||
authZErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid token",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: invalidToken,
|
||||
session: &invalidTokenSession,
|
||||
channelID: chanID,
|
||||
authNRes1: smqauthn.Session{},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with key and subtopic successfully",
|
||||
topic: &subtopic,
|
||||
|
||||
+32
-17
@@ -9,7 +9,8 @@ import (
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -42,14 +43,16 @@ var _ Service = (*adapterService)(nil)
|
||||
type adapterService struct {
|
||||
clients grpcClientsV1.ClientsServiceClient
|
||||
channels grpcChannelsV1.ChannelsServiceClient
|
||||
authn smqauthn.Authentication
|
||||
pubsub messaging.PubSub
|
||||
}
|
||||
|
||||
// New instantiates the WS adapter implementation.
|
||||
func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, pubsub messaging.PubSub) Service {
|
||||
func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn smqauthn.Authentication, pubsub messaging.PubSub) Service {
|
||||
return &adapterService{
|
||||
clients: clients,
|
||||
channels: channels,
|
||||
authn: authn,
|
||||
pubsub: pubsub,
|
||||
}
|
||||
}
|
||||
@@ -89,23 +92,35 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID,
|
||||
return nil
|
||||
}
|
||||
|
||||
// authorize checks if the clientKey is authorized to access the channel
|
||||
// and returns the clientID if it is.
|
||||
func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, chanID string, msgType connections.ConnType) (string, error) {
|
||||
if strings.HasPrefix(clientKey, "Client") {
|
||||
clientKey = extractClientSecret(clientKey)
|
||||
}
|
||||
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, domainID, clientKey)})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if !authnRes.GetAuthenticated() {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
// authorize checks if the authKey is authorized to access the channel
|
||||
// and returns the clientID or userID if it is.
|
||||
func (svc *adapterService) authorize(ctx context.Context, authKey, domainID, chanID string, msgType connections.ConnType) (string, error) {
|
||||
var clientID, clientType string
|
||||
switch {
|
||||
case strings.HasPrefix(authKey, apiutil.BearerPrefix):
|
||||
token := strings.TrimPrefix(authKey, apiutil.BearerPrefix)
|
||||
authnSession, err := svc.authn.Authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
clientType = policies.UserType
|
||||
clientID = authnSession.UserID
|
||||
default:
|
||||
secret := strings.TrimPrefix(authKey, apiutil.ClientPrefix)
|
||||
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if !authnRes.Authenticated {
|
||||
return "", svcerr.ErrAuthentication
|
||||
}
|
||||
clientType = policies.ClientType
|
||||
clientID = authnRes.GetId()
|
||||
}
|
||||
|
||||
authzReq := &grpcChannelsV1.AuthzReq{
|
||||
ClientType: policies.ClientType,
|
||||
ClientId: authnRes.GetId(),
|
||||
ClientType: clientType,
|
||||
ClientId: clientID,
|
||||
Type: uint32(msgType),
|
||||
ChannelId: chanID,
|
||||
DomainId: domainID,
|
||||
@@ -118,5 +133,5 @@ func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, c
|
||||
return "", errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
|
||||
return authnRes.GetId(), nil
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
+153
-106
@@ -12,10 +12,12 @@ import (
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
chmocks "github.com/absmach/supermq/channels/mocks"
|
||||
climocks "github.com/absmach/supermq/clients/mocks"
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -28,18 +30,21 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
chanID = "1"
|
||||
invalidID = "invalidID"
|
||||
invalidKey = "invalidKey"
|
||||
id = "1"
|
||||
clientKey = "client_key"
|
||||
subTopic = "subtopic"
|
||||
protocol = "ws"
|
||||
invalidID = "invalidID"
|
||||
invalidKey = "invalidKey"
|
||||
id = "1"
|
||||
clientKey = "client_key"
|
||||
subTopic = "subtopic"
|
||||
protocol = "ws"
|
||||
token = "Bearer token"
|
||||
invalidToken = "Bearer invalid_token"
|
||||
)
|
||||
|
||||
var (
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
clientID = testsutil.GenerateUUID(&testing.T{})
|
||||
userID = testsutil.GenerateUUID(&testing.T{})
|
||||
chanID = testsutil.GenerateUUID(&testing.T{})
|
||||
msg = messaging.Message{
|
||||
Channel: chanID,
|
||||
Domain: domainID,
|
||||
@@ -51,136 +56,172 @@ var (
|
||||
sessionID = "sessionID"
|
||||
)
|
||||
|
||||
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient) {
|
||||
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient, *authnmocks.Authentication) {
|
||||
pubsub := new(mocks.PubSub)
|
||||
clients := new(climocks.ClientsServiceClient)
|
||||
channels := new(chmocks.ChannelsServiceClient)
|
||||
authn := new(authnmocks.Authentication)
|
||||
|
||||
return ws.New(clients, channels, pubsub), pubsub, clients, channels
|
||||
return ws.New(clients, channels, authn, pubsub), pubsub, clients, channels, authn
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
svc, pubsub, clients, channels := newService()
|
||||
svc, pubsub, clients, channels, auth := newService()
|
||||
|
||||
c := ws.NewClient(slog.Default(), nil, sessionID)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
clientKey string
|
||||
chanID string
|
||||
domainID string
|
||||
subtopic string
|
||||
authNRes *grpcClientsV1.AuthnRes
|
||||
authNErr error
|
||||
authZRes *grpcChannelsV1.AuthzRes
|
||||
authZErr error
|
||||
subErr error
|
||||
err error
|
||||
desc string
|
||||
authKey string
|
||||
chanID string
|
||||
domainID string
|
||||
subtopic string
|
||||
clientType string
|
||||
clientID string
|
||||
authNRes *grpcClientsV1.AuthnRes
|
||||
authNErr error
|
||||
authNRes1 smqauthn.Session
|
||||
authZRes *grpcChannelsV1.AuthzRes
|
||||
authZErr error
|
||||
subErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "subscribe to channel with valid clientKey, chanID, subtopic",
|
||||
clientKey: clientKey,
|
||||
desc: "subscribe to channel with valid clientKey, chanID, subtopic",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with valid token, chanID, subtopic",
|
||||
authKey: token,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNRes1: smqauthn.Session{UserID: userID},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe again to channel with valid clientKey, chanID, subtopic",
|
||||
clientKey: clientKey,
|
||||
desc: "subscribe to channel with invalid token",
|
||||
authKey: invalidToken,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with subscribe set to fail",
|
||||
clientKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
subErr: ws.ErrFailedSubscription,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: ws.ErrFailedSubscription,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with invalid clientKey",
|
||||
clientKey: invalidKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
authNRes1: smqauthn.Session{},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with empty channel",
|
||||
clientKey: clientKey,
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "subscribe again to channel with valid clientKey, chanID, subtopic",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with empty clientKey",
|
||||
clientKey: "",
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "subscribe to channel with subscribe set to fail",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
subErr: ws.ErrFailedSubscription,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: ws.ErrFailedSubscription,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with empty clientKey and empty channel",
|
||||
clientKey: "",
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "subscribe to channel with invalid clientKey",
|
||||
authKey: invalidKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with invalid channel",
|
||||
clientKey: clientKey,
|
||||
chanID: invalidID,
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authZErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
desc: "subscribe to channel with empty channel",
|
||||
authKey: clientKey,
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with failed authentication",
|
||||
clientKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
desc: "subscribe to channel with empty clientKey",
|
||||
authKey: "",
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with failed authorization",
|
||||
clientKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
desc: "subscribe to channel with empty clientKey and empty channel",
|
||||
authKey: "",
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic",
|
||||
clientKey: "Client " + clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
desc: "subscribe to channel with invalid channel",
|
||||
authKey: clientKey,
|
||||
chanID: invalidID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authZErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with failed authentication",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with failed authorization",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic",
|
||||
authKey: "Client " + clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -188,26 +229,32 @@ func TestSubscribe(t *testing.T) {
|
||||
subConfig := messaging.SubscriberConfig{
|
||||
ID: sessionID,
|
||||
Topic: "m." + tc.domainID + ".c." + tc.chanID + "." + subTopic,
|
||||
ClientID: clientID,
|
||||
ClientID: tc.clientID,
|
||||
Handler: c,
|
||||
}
|
||||
authReq := &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, tc.domainID, tc.clientKey)}
|
||||
if strings.HasPrefix(tc.clientKey, "Client") {
|
||||
authReq.Token = authn.AuthPack(authn.DomainAuth, tc.domainID, strings.TrimPrefix(tc.clientKey, "Client "))
|
||||
authReq := &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.authKey)}
|
||||
tc.clientType = policies.ClientType
|
||||
if strings.HasPrefix(tc.authKey, "Client") {
|
||||
authReq.Token = smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, strings.TrimPrefix(tc.authKey, "Client "))
|
||||
}
|
||||
if strings.HasPrefix(tc.authKey, apiutil.BearerPrefix) {
|
||||
tc.clientType = policies.UserType
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, authReq).Return(tc.authNRes, tc.authNErr)
|
||||
authCall := auth.On("Authenticate", mock.Anything, strings.TrimPrefix(tc.authKey, apiutil.BearerPrefix)).Return(tc.authNRes1, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
ClientType: policies.ClientType,
|
||||
ClientId: tc.authNRes.GetId(),
|
||||
ClientType: tc.clientType,
|
||||
ClientId: tc.clientID,
|
||||
Type: uint32(connections.Subscribe),
|
||||
ChannelId: tc.chanID,
|
||||
DomainId: tc.domainID,
|
||||
}).Return(tc.authZRes, tc.authZErr)
|
||||
repoCall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
|
||||
err := svc.Subscribe(context.Background(), sessionID, tc.clientKey, tc.domainID, tc.chanID, tc.subtopic, c)
|
||||
err := svc.Subscribe(context.Background(), sessionID, tc.authKey, tc.domainID, tc.chanID, tc.subtopic, c)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
repoCall.Unset()
|
||||
clientsCall.Unset()
|
||||
authCall.Unset()
|
||||
channelsCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
+101
-89
@@ -39,6 +39,7 @@ const (
|
||||
clientKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529"
|
||||
protocol = "ws"
|
||||
instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
|
||||
validToken = "Bearer token"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -49,7 +50,8 @@ var (
|
||||
|
||||
func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) (ws.Service, *mocks.PubSub) {
|
||||
pubsub := new(mocks.PubSub)
|
||||
return ws.New(clients, channels, pubsub), pubsub
|
||||
authn := new(authnMocks.Authentication)
|
||||
return ws.New(clients, channels, authn, pubsub), pubsub
|
||||
}
|
||||
|
||||
func newHTTPServer(svc ws.Service, resolver messaging.TopicResolver) *httptest.Server {
|
||||
@@ -77,7 +79,7 @@ func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*ht
|
||||
return httptest.NewServer(http.HandlerFunc(mp.ServeHTTP)), nil
|
||||
}
|
||||
|
||||
func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (string, error) {
|
||||
func makeURL(tsURL, domainID, chanID, subtopic, authKey string, header bool) (string, error) {
|
||||
u, _ := url.Parse(tsURL)
|
||||
u.Scheme = protocol
|
||||
|
||||
@@ -85,7 +87,7 @@ func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (
|
||||
if header {
|
||||
return fmt.Sprintf("%s/m/%s/c/%s", u, domainID, chanID), fmt.Errorf("invalid channel id")
|
||||
}
|
||||
return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, clientKey), fmt.Errorf("invalid channel id")
|
||||
return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, authKey), fmt.Errorf("invalid channel id")
|
||||
}
|
||||
|
||||
subtopicPart := ""
|
||||
@@ -96,16 +98,16 @@ func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (
|
||||
return fmt.Sprintf("%s/m/%s/c/%s%s", u, domainID, chanID, subtopicPart), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, clientKey), nil
|
||||
return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, authKey), nil
|
||||
}
|
||||
|
||||
func handshake(tsURL, domainID, chanID, subtopic, clientKey string, addHeader bool) (*websocket.Conn, *http.Response, error) {
|
||||
func handshake(tsURL, domainID, chanID, subtopic, authKey string, addHeader bool) (*websocket.Conn, *http.Response, error) {
|
||||
header := http.Header{}
|
||||
if addHeader {
|
||||
header.Add("Authorization", clientKey)
|
||||
header.Add("Authorization", authKey)
|
||||
}
|
||||
|
||||
turl, _ := makeURL(tsURL, domainID, chanID, subtopic, clientKey, addHeader)
|
||||
turl, _ := makeURL(tsURL, domainID, chanID, subtopic, authKey, addHeader)
|
||||
conn, res, errRet := websocket.DefaultDialer.Dial(turl, header)
|
||||
|
||||
return conn, res, errRet
|
||||
@@ -138,111 +140,121 @@ func TestHandshake(t *testing.T) {
|
||||
channels.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
domainID string
|
||||
chanID string
|
||||
subtopic string
|
||||
header bool
|
||||
clientKey string
|
||||
status int
|
||||
err error
|
||||
msg []byte
|
||||
desc string
|
||||
domainID string
|
||||
chanID string
|
||||
subtopic string
|
||||
header bool
|
||||
authKey string
|
||||
status int
|
||||
err error
|
||||
msg []byte
|
||||
}{
|
||||
{
|
||||
desc: "connect and send message",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "",
|
||||
header: true,
|
||||
clientKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
desc: "connect and send message",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "",
|
||||
header: true,
|
||||
authKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
},
|
||||
{
|
||||
desc: "connect and send message with clientKey as query parameter",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "",
|
||||
header: false,
|
||||
clientKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
desc: "connect and send message with clientKey as query parameter",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "",
|
||||
header: false,
|
||||
authKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
},
|
||||
{
|
||||
desc: "connect and send message that cannot be published",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "",
|
||||
header: true,
|
||||
clientKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: []byte{},
|
||||
desc: "connect and send message with valid token",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "",
|
||||
header: true,
|
||||
authKey: validToken,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
},
|
||||
{
|
||||
desc: "connect and send message to subtopic",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "subtopic",
|
||||
header: true,
|
||||
clientKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
desc: "connect and send message that cannot be published",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "",
|
||||
header: true,
|
||||
authKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: []byte{},
|
||||
},
|
||||
{
|
||||
desc: "connect and send message to nested subtopic",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "subtopic/nested",
|
||||
header: true,
|
||||
clientKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
desc: "connect and send message to subtopic",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "subtopic",
|
||||
header: true,
|
||||
authKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
},
|
||||
{
|
||||
desc: "connect and send message to all subtopics",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: ">",
|
||||
header: true,
|
||||
clientKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
desc: "connect and send message to nested subtopic",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "subtopic/nested",
|
||||
header: true,
|
||||
authKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
},
|
||||
{
|
||||
desc: "connect to empty channel",
|
||||
domainID: domainID,
|
||||
chanID: "",
|
||||
subtopic: "",
|
||||
header: true,
|
||||
clientKey: clientKey,
|
||||
status: http.StatusUnauthorized,
|
||||
msg: []byte{},
|
||||
desc: "connect and send message to all subtopics",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: ">",
|
||||
header: true,
|
||||
authKey: clientKey,
|
||||
status: http.StatusSwitchingProtocols,
|
||||
msg: msg,
|
||||
},
|
||||
{
|
||||
desc: "connect with empty clientKey",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "",
|
||||
header: true,
|
||||
clientKey: "",
|
||||
status: http.StatusUnauthorized,
|
||||
msg: []byte{},
|
||||
desc: "connect to empty channel",
|
||||
domainID: domainID,
|
||||
chanID: "",
|
||||
subtopic: "",
|
||||
header: true,
|
||||
authKey: clientKey,
|
||||
status: http.StatusUnauthorized,
|
||||
msg: []byte{},
|
||||
},
|
||||
{
|
||||
desc: "connect and send message to subtopic with invalid name",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "sub/a*b/topic",
|
||||
header: true,
|
||||
clientKey: clientKey,
|
||||
status: http.StatusUnauthorized,
|
||||
msg: msg,
|
||||
desc: "connect with empty clientKey",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "",
|
||||
header: true,
|
||||
authKey: "",
|
||||
status: http.StatusUnauthorized,
|
||||
msg: []byte{},
|
||||
},
|
||||
{
|
||||
desc: "connect and send message to subtopic with invalid name",
|
||||
domainID: domainID,
|
||||
chanID: id,
|
||||
subtopic: "sub/a*b/topic",
|
||||
header: true,
|
||||
authKey: clientKey,
|
||||
status: http.StatusUnauthorized,
|
||||
msg: msg,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.clientKey, tc.header)
|
||||
conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.authKey, tc.header)
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode))
|
||||
|
||||
if tc.status == http.StatusSwitchingProtocols {
|
||||
|
||||
+2
-2
@@ -56,7 +56,7 @@ func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicReso
|
||||
|
||||
go client.Start(ctx)
|
||||
|
||||
if err := svc.Subscribe(ctx, sessionID, req.clientKey, req.domainID, req.channelID, req.subtopic, client); err != nil {
|
||||
if err := svc.Subscribe(ctx, sessionID, req.authKey, req.domainID, req.channelID, req.subtopic, client); err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
@@ -85,7 +85,7 @@ func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *sl
|
||||
}
|
||||
|
||||
req := connReq{
|
||||
clientKey: authKey,
|
||||
authKey: authKey,
|
||||
channelID: channelID,
|
||||
domainID: domainID,
|
||||
}
|
||||
|
||||
+2
-2
@@ -25,7 +25,7 @@ func LoggingMiddleware(svc ws.Service, logger *slog.Logger) ws.Service {
|
||||
|
||||
// Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) (err error) {
|
||||
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, authKey, domainID, chanID, subtopic string, c *ws.Client) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
@@ -44,7 +44,7 @@ func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey
|
||||
lm.logger.Info("Subscribe completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, c)
|
||||
return lm.svc.Subscribe(ctx, sessionID, authKey, domainID, chanID, subtopic, c)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) (err error) {
|
||||
|
||||
+2
-2
@@ -31,13 +31,13 @@ func MetricsMiddleware(svc ws.Service, counter metrics.Counter, latency metrics.
|
||||
}
|
||||
|
||||
// Subscribe instruments Subscribe method with metrics.
|
||||
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, c *ws.Client) error {
|
||||
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, authKey, domainID, chanID, subtopic string, c *ws.Client) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "subscribe").Add(1)
|
||||
mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, c)
|
||||
return mm.svc.Subscribe(ctx, sessionID, authKey, domainID, chanID, subtopic, c)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string) error {
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@
|
||||
package api
|
||||
|
||||
type connReq struct {
|
||||
clientKey string
|
||||
authKey string
|
||||
channelID string
|
||||
domainID string
|
||||
subtopic string
|
||||
|
||||
+28
-36
@@ -16,7 +16,6 @@ import (
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
@@ -85,25 +84,17 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
|
||||
return errClientNotInitialized
|
||||
}
|
||||
|
||||
var token string
|
||||
switch {
|
||||
case strings.HasPrefix(string(s.Password), "Client"):
|
||||
token = strings.ReplaceAll(string(s.Password), "Client ", "")
|
||||
default:
|
||||
token = string(s.Password)
|
||||
}
|
||||
|
||||
domainID, channelID, _, err := h.parser.ParsePublishTopic(ctx, *topic, true)
|
||||
if err != nil {
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
|
||||
}
|
||||
|
||||
clientID, clientType, err := h.authAccess(ctx, token, domainID, channelID, connections.Publish)
|
||||
clientID, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Publish)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.Username == "" && clientType == policies.ClientType {
|
||||
if s.Username == "" {
|
||||
s.Username = clientID
|
||||
}
|
||||
|
||||
@@ -126,7 +117,7 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, _, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Subscribe); err != nil {
|
||||
if _, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Subscribe); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -194,19 +185,29 @@ func (h *handler) Disconnect(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType) (string, string, error) {
|
||||
if strings.HasPrefix(token, "Client") {
|
||||
token = extractClientSecret(token)
|
||||
func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType) (string, error) {
|
||||
var clientID, clientType string
|
||||
switch {
|
||||
case strings.HasPrefix(token, apiutil.BearerPrefix):
|
||||
token := strings.TrimPrefix(token, apiutil.BearerPrefix)
|
||||
authnSession, err := h.authn.Authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
clientType = policies.UserType
|
||||
clientID = authnSession.UserID
|
||||
default:
|
||||
secret := strings.TrimPrefix(token, apiutil.ClientPrefix)
|
||||
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
|
||||
if err != nil {
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
if !authnRes.Authenticated {
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
clientType = policies.ClientType
|
||||
clientID = authnRes.GetId()
|
||||
}
|
||||
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, domainID, token)})
|
||||
if err != nil {
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
|
||||
}
|
||||
if !authnRes.GetAuthenticated() {
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
clientType := policies.ClientType
|
||||
clientID := authnRes.GetId()
|
||||
|
||||
ar := &grpcChannelsV1.AuthzReq{
|
||||
Type: uint32(msgType),
|
||||
@@ -217,20 +218,11 @@ func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string
|
||||
}
|
||||
res, err := h.channels.Authorize(ctx, ar)
|
||||
if err != nil {
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
|
||||
}
|
||||
if !res.GetAuthorized() {
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
|
||||
return clientID, clientType, nil
|
||||
}
|
||||
|
||||
// extractClientSecret returns value of the client secret. If there is no client key - an empty value is returned.
|
||||
func extractClientSecret(token string) string {
|
||||
if !strings.HasPrefix(token, apiutil.ClientPrefix) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimPrefix(token, apiutil.ClientPrefix)
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,497 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package ws_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
mgate "github.com/absmach/mgate/pkg/http"
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
chmocks "github.com/absmach/supermq/channels/mocks"
|
||||
clmocks "github.com/absmach/supermq/clients/mocks"
|
||||
dmocks "github.com/absmach/supermq/domains/mocks"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/pkg/messaging/mocks"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
"github.com/absmach/supermq/ws"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
var (
|
||||
invalidValue = "invalid"
|
||||
topicMsg = "/m/%s/c/%s"
|
||||
subtopicMsg = "/m/%s/c/%s/subtopic"
|
||||
topic = fmt.Sprintf(topicMsg, domainID, chanID)
|
||||
subtopic = fmt.Sprintf(subtopicMsg, domainID, chanID)
|
||||
invalidTopic = invalidValue
|
||||
topics = []string{topic}
|
||||
payload = []byte("[{'n':'test-name', 'v': 1.2}]")
|
||||
sessionClient = session.Session{
|
||||
ID: clientID,
|
||||
Password: []byte(clientKey),
|
||||
}
|
||||
invalidChannelIDTopic = "m/**/c"
|
||||
validToken = "token"
|
||||
errClientNotInitialized = errors.New("client is not initialized")
|
||||
errMissingTopicPub = errors.New("failed to publish due to missing topic")
|
||||
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
|
||||
)
|
||||
|
||||
var (
|
||||
clients = new(clmocks.ClientsServiceClient)
|
||||
channels = new(chmocks.ChannelsServiceClient)
|
||||
authn = new(authnmocks.Authentication)
|
||||
publisher = new(mocks.PubSub)
|
||||
domains = new(dmocks.DomainsServiceClient)
|
||||
)
|
||||
|
||||
func newHandler(t *testing.T) session.Handler {
|
||||
logger := smqlog.NewMock()
|
||||
authn = new(authnmocks.Authentication)
|
||||
clients = new(clmocks.ClientsServiceClient)
|
||||
channels = new(chmocks.ChannelsServiceClient)
|
||||
publisher = new(mocks.PubSub)
|
||||
parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains)
|
||||
assert.Nil(t, err, fmt.Sprintf("unexpected error while creating topic parser: %v", err))
|
||||
|
||||
return ws.NewHandler(publisher, logger, authn, clients, channels, parser)
|
||||
}
|
||||
|
||||
func TestAuthPublish(t *testing.T) {
|
||||
handler := newHandler(t)
|
||||
|
||||
clientKeySession := session.Session{
|
||||
Password: []byte("Client " + clientKey),
|
||||
}
|
||||
invalidClientKeySession := session.Session{
|
||||
Password: []byte("Client " + invalidKey),
|
||||
}
|
||||
|
||||
tokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + validToken),
|
||||
}
|
||||
invalidTokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + invalidToken),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
session *session.Session
|
||||
topic *string
|
||||
payload *[]byte
|
||||
authKey string
|
||||
status int
|
||||
clientType string
|
||||
chanID string
|
||||
domainID string
|
||||
clientID string
|
||||
authNRes *grpcClientsV1.AuthnRes
|
||||
authNRes1 smqauthn.Session
|
||||
authNErr error
|
||||
authZRes *grpcChannelsV1.AuthzRes
|
||||
authZErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish with client key successfully",
|
||||
session: &clientKeySession,
|
||||
topic: &topic,
|
||||
authKey: clientKey,
|
||||
payload: &payload,
|
||||
status: http.StatusOK,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid client key",
|
||||
session: &invalidClientKeySession,
|
||||
topic: &topic,
|
||||
authKey: invalidKey,
|
||||
payload: &payload,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with nil session",
|
||||
session: nil,
|
||||
topic: &topic,
|
||||
authKey: clientKey,
|
||||
status: http.StatusInternalServerError,
|
||||
err: errClientNotInitialized,
|
||||
},
|
||||
{
|
||||
desc: "publish with empty topic",
|
||||
session: &clientKeySession,
|
||||
topic: nil,
|
||||
authKey: clientKey,
|
||||
status: http.StatusBadRequest,
|
||||
err: errMissingTopicPub,
|
||||
},
|
||||
{
|
||||
desc: "publish with unauthorized client key",
|
||||
session: &clientKeySession,
|
||||
topic: &topic,
|
||||
authKey: clientKey,
|
||||
payload: &payload,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with token successfully",
|
||||
session: &tokenSession,
|
||||
topic: &topic,
|
||||
authKey: token,
|
||||
payload: &payload,
|
||||
status: http.StatusOK,
|
||||
clientType: policies.UserType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
authNRes1: smqauthn.Session{UserID: userID},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid token",
|
||||
session: &invalidTokenSession,
|
||||
topic: &topic,
|
||||
authKey: invalidToken,
|
||||
payload: &payload,
|
||||
clientType: policies.UserType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
authNRes1: smqauthn.Session{},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with unauthorized client key",
|
||||
session: &tokenSession,
|
||||
topic: &topic,
|
||||
authKey: token,
|
||||
payload: &payload,
|
||||
clientType: policies.UserType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
authNRes1: smqauthn.Session{UserID: userID},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
if tc.session != nil {
|
||||
ctx = session.NewContext(ctx, tc.session)
|
||||
}
|
||||
tc.clientType = policies.ClientType
|
||||
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
|
||||
tc.clientType = policies.UserType
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.authKey)}).Return(tc.authNRes, tc.authNErr)
|
||||
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
ClientType: tc.clientType,
|
||||
ClientId: tc.clientID,
|
||||
Type: uint32(connections.Publish),
|
||||
ChannelId: tc.chanID,
|
||||
DomainId: tc.domainID,
|
||||
}).Return(tc.authZRes, tc.authZErr)
|
||||
err := handler.AuthPublish(ctx, tc.topic, tc.payload)
|
||||
hpe, ok := err.(mgate.HTTPProxyError)
|
||||
if ok {
|
||||
assert.Equal(t, tc.status, hpe.StatusCode())
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
authCall.Unset()
|
||||
clientsCall.Unset()
|
||||
channelsCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthSubscribe(t *testing.T) {
|
||||
handler := newHandler(t)
|
||||
|
||||
clientKeySession := session.Session{
|
||||
Password: []byte("Client " + clientKey),
|
||||
}
|
||||
invalidClientKeySession := session.Session{
|
||||
Password: []byte("Client " + invalidKey),
|
||||
}
|
||||
|
||||
tokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + validToken),
|
||||
}
|
||||
invalidTokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + invalidToken),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
session *session.Session
|
||||
topics *[]string
|
||||
authKey string
|
||||
status int
|
||||
clientType string
|
||||
chanID string
|
||||
domainID string
|
||||
clientID string
|
||||
authNRes *grpcClientsV1.AuthnRes
|
||||
authNRes1 smqauthn.Session
|
||||
authNErr error
|
||||
authZRes *grpcChannelsV1.AuthzRes
|
||||
authZErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "subscribe with client key successfully",
|
||||
session: &clientKeySession,
|
||||
topics: &topics,
|
||||
authKey: clientKey,
|
||||
status: http.StatusOK,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with invalid client key",
|
||||
session: &invalidClientKeySession,
|
||||
topics: &topics,
|
||||
authKey: invalidKey,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with empty topics",
|
||||
session: &clientKeySession,
|
||||
topics: nil,
|
||||
authKey: clientKey,
|
||||
status: http.StatusBadRequest,
|
||||
err: errMissingTopicSub,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with nil session",
|
||||
session: nil,
|
||||
topics: &topics,
|
||||
authKey: clientKey,
|
||||
status: http.StatusInternalServerError,
|
||||
err: errClientNotInitialized,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with unauthorized client key",
|
||||
session: &clientKeySession,
|
||||
topics: &topics,
|
||||
authKey: clientKey,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with token successfully",
|
||||
session: &tokenSession,
|
||||
topics: &topics,
|
||||
authKey: token,
|
||||
status: http.StatusOK,
|
||||
clientType: policies.UserType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
authNRes1: smqauthn.Session{UserID: userID},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with invalid token",
|
||||
session: &invalidTokenSession,
|
||||
topics: &topics,
|
||||
authKey: invalidToken,
|
||||
clientType: policies.UserType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
authNRes1: smqauthn.Session{},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with unauthorized client key",
|
||||
session: &tokenSession,
|
||||
topics: &topics,
|
||||
authKey: token,
|
||||
clientType: policies.UserType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
authNRes1: smqauthn.Session{UserID: userID},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
if tc.session != nil {
|
||||
ctx = session.NewContext(ctx, tc.session)
|
||||
}
|
||||
tc.clientType = policies.ClientType
|
||||
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
|
||||
tc.clientType = policies.UserType
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.authKey)}).Return(tc.authNRes, tc.authNErr)
|
||||
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
ClientType: tc.clientType,
|
||||
ClientId: tc.clientID,
|
||||
Type: uint32(connections.Subscribe),
|
||||
ChannelId: tc.chanID,
|
||||
DomainId: tc.domainID,
|
||||
}).Return(tc.authZRes, tc.authZErr)
|
||||
err := handler.AuthSubscribe(ctx, tc.topics)
|
||||
hpe, ok := err.(mgate.HTTPProxyError)
|
||||
if ok {
|
||||
assert.Equal(t, tc.status, hpe.StatusCode())
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
authCall.Unset()
|
||||
clientsCall.Unset()
|
||||
channelsCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
handler := newHandler(t)
|
||||
|
||||
malformedSubtopics := topic + "/" + subtopic + "%"
|
||||
wrongCharSubtopics := topic + "/" + subtopic + ">"
|
||||
validSubtopic := topic + "/" + subtopic
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session *session.Session
|
||||
topic string
|
||||
payload []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish without active session",
|
||||
session: nil,
|
||||
topic: topic,
|
||||
payload: payload,
|
||||
err: errClientNotInitialized,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid topic",
|
||||
session: &sessionClient,
|
||||
topic: invalidTopic,
|
||||
payload: payload,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid channel ID",
|
||||
session: &sessionClient,
|
||||
topic: invalidChannelIDTopic,
|
||||
payload: payload,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "publish with malformed subtopic",
|
||||
session: &sessionClient,
|
||||
topic: malformedSubtopics,
|
||||
payload: payload,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "publish with subtopic containing wrong character",
|
||||
session: &sessionClient,
|
||||
topic: wrongCharSubtopics,
|
||||
payload: payload,
|
||||
err: messaging.ErrMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "publish with subtopic",
|
||||
session: &sessionClient,
|
||||
topic: validSubtopic,
|
||||
payload: payload,
|
||||
},
|
||||
{
|
||||
desc: "publish without subtopic",
|
||||
session: &sessionClient,
|
||||
topic: topic,
|
||||
payload: payload,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
ctx := context.TODO()
|
||||
if tc.session != nil {
|
||||
ctx = session.NewContext(ctx, tc.session)
|
||||
}
|
||||
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
err := handler.Publish(ctx, &tc.topic, &tc.payload)
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user