mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
SMQ-2799 - Add support for basic auth for HTTP and WS adapters (#3049)
Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
@@ -8,11 +8,16 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BearerPrefix represents the token prefix for Bearer authentication scheme.
|
||||
const BearerPrefix = "Bearer "
|
||||
const (
|
||||
// BearerPrefix represents the token prefix for Bearer authentication scheme.
|
||||
BearerPrefix = "Bearer "
|
||||
|
||||
// ClientPrefix represents the key prefix for Client authentication scheme.
|
||||
const ClientPrefix = "Client "
|
||||
// ClientPrefix represents the key prefix for Client authentication scheme.
|
||||
ClientPrefix = "Client "
|
||||
|
||||
// BasicAuthPrefix represents the prefix for Basic authentication scheme.
|
||||
BasicAuthPrefix = "Basic "
|
||||
)
|
||||
|
||||
// ExtractBearerToken returns value of the bearer token. If there is no bearer token - an empty value is returned.
|
||||
func ExtractBearerToken(r *http.Request) string {
|
||||
|
||||
+71
-34
@@ -5,6 +5,7 @@ package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -26,19 +27,12 @@ import (
|
||||
|
||||
var _ session.Handler = (*handler)(nil)
|
||||
|
||||
type ctxKey string
|
||||
|
||||
const (
|
||||
protocol = "http"
|
||||
clientIDCtxKey ctxKey = "client_id"
|
||||
clientTypeCtxKey ctxKey = "client_type"
|
||||
)
|
||||
const protocol = "http"
|
||||
|
||||
// Log message formats.
|
||||
const (
|
||||
logInfoPublished = "published with client_type %s client_id %s to the topic %s"
|
||||
logInfoFailedAuthNToken = "failed to authenticate token for topic %s with error %s"
|
||||
logInfoFailedAuthNClient = "failed to authenticate client key %s for topic %s with error %s"
|
||||
publishedInfoFmt = "published with client_type %s client_id %s to the topic %s"
|
||||
failedAuthnFmt = "failed to authenticate client_type %s for topic %s with error %s"
|
||||
)
|
||||
|
||||
// Error wrappers for MQTT errors.
|
||||
@@ -48,6 +42,8 @@ var (
|
||||
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
|
||||
errMalformedTopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed topic"))
|
||||
errMissingTopicPub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to publish due to missing topic"))
|
||||
errInvalidAuthFormat = errors.New("invalid basic auth format")
|
||||
errInvalidClientType = errors.New("invalid client type")
|
||||
)
|
||||
|
||||
// Event implements events.Event interface.
|
||||
@@ -118,37 +114,39 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
return errors.Wrap(errMalformedTopic, err)
|
||||
}
|
||||
|
||||
var clientID, clientType string
|
||||
var token, clientType string
|
||||
pass := string(s.Password)
|
||||
switch {
|
||||
case strings.HasPrefix(string(s.Password), "Client"):
|
||||
secret := strings.TrimPrefix(string(s.Password), apiutil.ClientPrefix)
|
||||
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
|
||||
if err != nil {
|
||||
h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, *topic, err))
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
if !authnRes.Authenticated {
|
||||
h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, *topic, svcerr.ErrAuthentication))
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
case s.Username != "" && pass != "":
|
||||
token = smqauthn.AuthPack(smqauthn.BasicAuth, s.Username, pass)
|
||||
clientType = policies.ClientType
|
||||
clientID = authnRes.GetId()
|
||||
case strings.HasPrefix(string(s.Password), apiutil.BearerPrefix):
|
||||
token := strings.TrimPrefix(string(s.Password), apiutil.BearerPrefix)
|
||||
authnSession, err := h.authn.Authenticate(ctx, token)
|
||||
case strings.HasPrefix(pass, apiutil.BasicAuthPrefix):
|
||||
username, password, err := decodeAuth(strings.TrimPrefix(pass, apiutil.BasicAuthPrefix))
|
||||
if err != nil {
|
||||
h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNToken, *topic, err))
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
h.logger.Warn(fmt.Sprintf(failedAuthnFmt, policies.ClientType, *topic, err))
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, err)
|
||||
}
|
||||
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
|
||||
clientType = policies.ClientType
|
||||
case strings.HasPrefix(pass, apiutil.ClientPrefix):
|
||||
token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(pass, apiutil.ClientPrefix))
|
||||
clientType = policies.ClientType
|
||||
case strings.HasPrefix(pass, apiutil.BearerPrefix):
|
||||
token = strings.TrimPrefix(pass, apiutil.BearerPrefix)
|
||||
clientType = policies.UserType
|
||||
clientID = authnSession.UserID
|
||||
default:
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
|
||||
id, err := h.authenticate(ctx, clientType, token)
|
||||
if err != nil {
|
||||
h.logger.Warn(fmt.Sprintf(failedAuthnFmt, clientType, *topic, err))
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, err)
|
||||
}
|
||||
|
||||
// Health topics are not published to message broker.
|
||||
if topicType == messaging.HealthType {
|
||||
h.logger.Info(fmt.Sprintf(logInfoPublished, clientType, clientID, *topic))
|
||||
h.logger.Info(fmt.Sprintf(publishedInfoFmt, clientType, id, *topic))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -157,21 +155,21 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
Domain: domainID,
|
||||
Channel: channelID,
|
||||
Subtopic: subtopic,
|
||||
Publisher: clientID,
|
||||
Publisher: id,
|
||||
Payload: *payload,
|
||||
Created: time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
ar := &grpcChannelsV1.AuthzReq{
|
||||
DomainId: domainID,
|
||||
ClientId: clientID,
|
||||
ClientId: id,
|
||||
ClientType: clientType,
|
||||
ChannelId: msg.Channel,
|
||||
Type: uint32(connections.Publish),
|
||||
}
|
||||
res, err := h.channels.Authorize(ctx, ar)
|
||||
if err != nil {
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, err)
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, err)
|
||||
}
|
||||
if !res.GetAuthorized() {
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthorization)
|
||||
@@ -181,7 +179,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
return errors.Wrap(errFailedPublishToMsgBroker, err)
|
||||
}
|
||||
|
||||
h.logger.Info(fmt.Sprintf(logInfoPublished, clientType, clientID, *topic))
|
||||
h.logger.Info(fmt.Sprintf(publishedInfoFmt, clientType, id, *topic))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -200,3 +198,42 @@ func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error {
|
||||
func (h *handler) Disconnect(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handler) authenticate(ctx context.Context, authType, token string) (string, error) {
|
||||
switch authType {
|
||||
case policies.UserType:
|
||||
authnSession, err := h.authn.Authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return authnSession.UserID, nil
|
||||
case policies.ClientType:
|
||||
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if !authnRes.Authenticated {
|
||||
return "", svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
return authnRes.GetId(), nil
|
||||
default:
|
||||
return "", errInvalidClientType
|
||||
}
|
||||
}
|
||||
|
||||
// decodeAuth decodes the base64 encoded string in the format "clientID:secret".
|
||||
func decodeAuth(s string) (string, string, error) {
|
||||
db, err := base64.URLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
parts := strings.SplitN(string(db), ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", errInvalidAuthFormat
|
||||
}
|
||||
clientID := parts[0]
|
||||
secret := parts[1]
|
||||
|
||||
return clientID, secret, nil
|
||||
}
|
||||
|
||||
+142
-76
@@ -5,6 +5,7 @@ package http_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -145,21 +146,34 @@ func TestPublish(t *testing.T) {
|
||||
clientKeySession := session.Session{
|
||||
Password: []byte("Client " + clientKey),
|
||||
}
|
||||
|
||||
tokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + validToken),
|
||||
}
|
||||
invalidTokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + invalidToken),
|
||||
}
|
||||
basicAuthSession := session.Session{
|
||||
Username: clientID,
|
||||
Password: []byte(clientKey),
|
||||
}
|
||||
invalidBasicAuthSession := session.Session{
|
||||
Username: clientID,
|
||||
Password: []byte(invalidValue),
|
||||
}
|
||||
creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
|
||||
encodedCredsSession := session.Session{
|
||||
Password: []byte(apiutil.BasicAuthPrefix + creds),
|
||||
}
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic *string
|
||||
channelID string
|
||||
username string
|
||||
payload *[]byte
|
||||
password string
|
||||
session *session.Session
|
||||
status int
|
||||
authNToken string
|
||||
authNRes *grpcClientsV1.AuthnRes
|
||||
authNRes1 smqauthn.Session
|
||||
authNErr error
|
||||
@@ -169,17 +183,18 @@ func TestPublish(t *testing.T) {
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish with key successfully",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
authZErr: nil,
|
||||
err: nil,
|
||||
desc: "publish with key successfully",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
authZErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with token successfully",
|
||||
@@ -207,17 +222,18 @@ func TestPublish(t *testing.T) {
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with key and subtopic successfully",
|
||||
topic: &subtopic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
authZErr: nil,
|
||||
err: nil,
|
||||
desc: "publish with key and subtopic successfully",
|
||||
topic: &subtopic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
authZErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with empty topic",
|
||||
@@ -238,14 +254,15 @@ func TestPublish(t *testing.T) {
|
||||
err: errClientNotInitialized,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid topic",
|
||||
topic: &invalidTopic,
|
||||
status: http.StatusBadRequest,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
err: errMalformedTopic,
|
||||
desc: "publish with invalid topic",
|
||||
topic: &invalidTopic,
|
||||
status: http.StatusBadRequest,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
err: errMalformedTopic,
|
||||
},
|
||||
{
|
||||
desc: "publish with malformwd subtopic",
|
||||
@@ -269,28 +286,30 @@ func TestPublish(t *testing.T) {
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with client key and failed to authenticate",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
status: http.StatusUnauthorized,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
|
||||
authNErr: nil,
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "publish with client key and failed to authenticate",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
status: http.StatusUnauthorized,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
|
||||
authNErr: nil,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with client key and failed to authenticate with error",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
status: http.StatusUnauthorized,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "publish with client key and failed to authenticate with error",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
status: http.StatusUnauthorized,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "publish with token and failed to authenticate",
|
||||
@@ -305,32 +324,78 @@ func TestPublish(t *testing.T) {
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with unauthorized client",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
status: http.StatusUnauthorized,
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authZErr: nil,
|
||||
err: svcerr.ErrAuthorization,
|
||||
desc: "publish with basic auth",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
username: clientID,
|
||||
password: clientKey,
|
||||
session: &basicAuthSession,
|
||||
channelID: chanID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
authZErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with authorization error",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
status: http.StatusBadRequest,
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authZErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
desc: "publish with invalid basic auth",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
username: clientID,
|
||||
password: clientKey,
|
||||
session: &invalidBasicAuthSession,
|
||||
channelID: chanID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
|
||||
authNRes: &grpcClientsV1.AuthnRes{},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "publish with encoded credentials",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
username: clientID,
|
||||
password: clientKey,
|
||||
session: &encodedCredsSession,
|
||||
channelID: chanID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
authZErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with unauthorized client",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
status: http.StatusUnauthorized,
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authZErr: nil,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "publish with authorization error",
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
status: http.StatusUnauthorized,
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authZErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "publish with failed to publish",
|
||||
@@ -339,6 +404,7 @@ func TestPublish(t *testing.T) {
|
||||
password: clientKey,
|
||||
session: &clientKeySession,
|
||||
channelID: chanID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
@@ -375,7 +441,7 @@ func TestPublish(t *testing.T) {
|
||||
if tc.topic != nil {
|
||||
internalTopic = strings.TrimPrefix(strings.ReplaceAll(*tc.topic, "/", "."), ".m.")
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.password)}).Return(tc.authNRes, tc.authNErr)
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
|
||||
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", ctx, mock.Anything).Return(tc.authZRes, tc.authZErr)
|
||||
repoCall := publisher.On("Publish", ctx, internalTopic, mock.Anything).Return(tc.publishErr)
|
||||
|
||||
@@ -121,7 +121,6 @@ func TestSendMessage(t *testing.T) {
|
||||
msg: msg,
|
||||
secret: "",
|
||||
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
|
||||
authErr: svcerr.ErrAuthentication,
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
|
||||
},
|
||||
@@ -132,7 +131,6 @@ func TestSendMessage(t *testing.T) {
|
||||
msg: msg,
|
||||
secret: "invalid",
|
||||
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
|
||||
authErr: svcerr.ErrAuthentication,
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
|
||||
},
|
||||
@@ -143,7 +141,6 @@ func TestSendMessage(t *testing.T) {
|
||||
msg: msg,
|
||||
secret: clientKey,
|
||||
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
|
||||
authErr: svcerr.ErrAuthentication,
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
|
||||
},
|
||||
@@ -176,7 +173,6 @@ func TestSendMessage(t *testing.T) {
|
||||
msg: msg,
|
||||
secret: clientKey,
|
||||
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
|
||||
authErr: svcerr.ErrAuthentication,
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
|
||||
},
|
||||
|
||||
+49
-23
@@ -33,7 +33,7 @@ type Service interface {
|
||||
// the channelID for subscription and domainID specifies the domain for authorization.
|
||||
// Subtopic is optional.
|
||||
// If the subscription is successful, nil is returned otherwise error is returned.
|
||||
Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, topicType messaging.TopicType, client *Client) error
|
||||
Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, client *Client) error
|
||||
|
||||
Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) error
|
||||
}
|
||||
@@ -57,12 +57,12 @@ func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.Cha
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey, domainID, channelID, subtopic string, topicType messaging.TopicType, c *Client) error {
|
||||
if (channelID == "" && topicType != messaging.HealthType) || clientKey == "" || domainID == "" {
|
||||
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, username, password, domainID, channelID, subtopic string, topicType messaging.TopicType, c *Client) error {
|
||||
if (channelID == "" && topicType != messaging.HealthType) || password == "" || domainID == "" {
|
||||
return svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
clientID, err := svc.authorize(ctx, clientKey, domainID, channelID, connections.Subscribe, topicType)
|
||||
clientID, err := svc.authorize(ctx, username, password, domainID, channelID, connections.Subscribe, topicType)
|
||||
if err != nil {
|
||||
return svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -102,38 +102,41 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID,
|
||||
|
||||
// authorize checks if the authKey is authorized to access the channel
|
||||
// and returns the clientID or userID if it is.
|
||||
func (svc *adapterService) authorize(ctx context.Context, authKey, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
|
||||
var clientID, clientType string
|
||||
func (svc *adapterService) authorize(ctx context.Context, username, password, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
|
||||
var token, clientType string
|
||||
var err error
|
||||
switch {
|
||||
case strings.HasPrefix(authKey, apiutil.BearerPrefix):
|
||||
token := strings.TrimPrefix(authKey, apiutil.BearerPrefix)
|
||||
authnSession, err := svc.authn.Authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
case strings.HasPrefix(password, apiutil.BearerPrefix):
|
||||
token = strings.TrimPrefix(password, apiutil.BearerPrefix)
|
||||
clientType = policies.UserType
|
||||
clientID = authnSession.UserID
|
||||
default:
|
||||
secret := strings.TrimPrefix(authKey, apiutil.ClientPrefix)
|
||||
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
|
||||
case username != "" && password != "":
|
||||
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
|
||||
clientType = policies.ClientType
|
||||
case strings.HasPrefix(password, apiutil.BasicAuthPrefix):
|
||||
username, password, err := decodeAuth(strings.TrimPrefix(password, apiutil.BasicAuthPrefix))
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if !authnRes.Authenticated {
|
||||
return "", svcerr.ErrAuthentication
|
||||
}
|
||||
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
|
||||
clientType = policies.ClientType
|
||||
clientID = authnRes.GetId()
|
||||
default:
|
||||
token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(password, apiutil.ClientPrefix))
|
||||
clientType = policies.ClientType
|
||||
}
|
||||
|
||||
id, err := svc.authenticate(ctx, clientType, token)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
// Health check topics do not require channel authorization.
|
||||
if topicType == messaging.HealthType {
|
||||
return clientID, nil
|
||||
return id, nil
|
||||
}
|
||||
|
||||
authzReq := &grpcChannelsV1.AuthzReq{
|
||||
ClientType: clientType,
|
||||
ClientId: clientID,
|
||||
ClientId: id,
|
||||
Type: uint32(msgType),
|
||||
ChannelId: chanID,
|
||||
DomainId: domainID,
|
||||
@@ -146,5 +149,28 @@ func (svc *adapterService) authorize(ctx context.Context, authKey, domainID, cha
|
||||
return "", errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (svc *adapterService) authenticate(ctx context.Context, authType, token string) (string, error) {
|
||||
switch authType {
|
||||
case policies.UserType:
|
||||
authnSession, err := svc.authn.Authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return authnSession.UserID, nil
|
||||
case policies.ClientType:
|
||||
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !authnRes.Authenticated {
|
||||
return "", svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
return authnRes.GetId(), nil
|
||||
default:
|
||||
return "", errInvalidClientType
|
||||
}
|
||||
}
|
||||
|
||||
+194
-128
@@ -5,6 +5,7 @@ package ws_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
@@ -53,7 +54,9 @@ var (
|
||||
Protocol: protocol,
|
||||
Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`),
|
||||
}
|
||||
sessionID = "sessionID"
|
||||
sessionID = "sessionID"
|
||||
validEncodedCreds = base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
|
||||
invalidEncodedCreds = base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", invalidID, invalidKey)))
|
||||
)
|
||||
|
||||
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient, *authnmocks.Authentication) {
|
||||
@@ -72,13 +75,15 @@ func TestSubscribe(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
authKey string
|
||||
username string
|
||||
password string
|
||||
chanID string
|
||||
domainID string
|
||||
subtopic string
|
||||
clientType string
|
||||
clientID string
|
||||
topicType messaging.TopicType
|
||||
authNToken string
|
||||
authNRes *grpcClientsV1.AuthnRes
|
||||
authNErr error
|
||||
authNRes1 smqauthn.Session
|
||||
@@ -88,20 +93,21 @@ func TestSubscribe(t *testing.T) {
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "subscribe to channel with valid clientKey, chanID, subtopic",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
desc: "subscribe to channel with valid clientKey, chanID, subtopic",
|
||||
password: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with valid token, chanID, subtopic",
|
||||
authKey: token,
|
||||
password: token,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
@@ -113,7 +119,7 @@ func TestSubscribe(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with invalid token",
|
||||
authKey: invalidToken,
|
||||
password: invalidToken,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
subtopic: subTopic,
|
||||
@@ -123,45 +129,48 @@ func TestSubscribe(t *testing.T) {
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe again to channel with valid clientKey, chanID, subtopic",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
desc: "subscribe again to channel with valid clientKey, chanID, subtopic",
|
||||
password: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with subscribe set to fail",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
subErr: ws.ErrFailedSubscription,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: ws.ErrFailedSubscription,
|
||||
desc: "subscribe to channel with subscribe set to fail",
|
||||
password: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
subErr: ws.ErrFailedSubscription,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: ws.ErrFailedSubscription,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with invalid clientKey",
|
||||
authKey: invalidKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthorization,
|
||||
desc: "subscribe to channel with invalid clientKey",
|
||||
password: invalidKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with empty channel",
|
||||
authKey: clientKey,
|
||||
password: clientKey,
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
@@ -171,7 +180,7 @@ func TestSubscribe(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with empty clientKey",
|
||||
authKey: "",
|
||||
password: "",
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
@@ -181,7 +190,7 @@ func TestSubscribe(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with empty clientKey and empty channel",
|
||||
authKey: "",
|
||||
password: "",
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
@@ -190,66 +199,125 @@ func TestSubscribe(t *testing.T) {
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with invalid channel",
|
||||
authKey: clientKey,
|
||||
chanID: invalidID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authZErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
desc: "subscribe to channel with invalid channel",
|
||||
password: clientKey,
|
||||
chanID: invalidID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authZErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with failed authentication",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
desc: "subscribe to channel with failed authentication",
|
||||
password: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with failed authorization",
|
||||
authKey: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
desc: "subscribe to channel with failed authorization",
|
||||
password: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic",
|
||||
authKey: "Client " + clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic",
|
||||
password: "Client " + clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to health check topic with empty channel and valid clientKey",
|
||||
authKey: clientKey,
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
topicType: messaging.HealthType,
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
err: nil,
|
||||
desc: "subscribe to channel with basic auth",
|
||||
username: clientID,
|
||||
password: clientKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with basic auth and invalid credentials",
|
||||
username: invalidID,
|
||||
password: invalidKey,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: invalidID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, invalidID, invalidKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with b64 encoded credentials",
|
||||
password: apiutil.BasicAuthPrefix + validEncodedCreds,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to channel with b64 encoded credentials and invalid credentials",
|
||||
password: apiutil.BasicAuthPrefix + invalidEncodedCreds,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: invalidID,
|
||||
subtopic: subTopic,
|
||||
topicType: messaging.MessageType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, invalidID, invalidKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to health check topic with empty channel and valid clientKey",
|
||||
password: clientKey,
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
topicType: messaging.HealthType,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe to health check topic with empty channel and valid token",
|
||||
authKey: token,
|
||||
password: token,
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
@@ -259,7 +327,7 @@ func TestSubscribe(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "subscribe to health check topic with empty domain and valid clientKey",
|
||||
authKey: clientKey,
|
||||
password: clientKey,
|
||||
chanID: "",
|
||||
domainID: "",
|
||||
clientID: clientID,
|
||||
@@ -270,35 +338,33 @@ func TestSubscribe(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
subConfig := messaging.SubscriberConfig{
|
||||
ID: sessionID,
|
||||
Topic: "m." + tc.domainID + ".c." + tc.chanID + "." + subTopic,
|
||||
ClientID: tc.clientID,
|
||||
Handler: c,
|
||||
}
|
||||
authReq := &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.authKey)}
|
||||
tc.clientType = policies.ClientType
|
||||
if strings.HasPrefix(tc.authKey, "Client") {
|
||||
authReq.Token = smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, strings.TrimPrefix(tc.authKey, "Client "))
|
||||
}
|
||||
if strings.HasPrefix(tc.authKey, apiutil.BearerPrefix) {
|
||||
tc.clientType = policies.UserType
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, authReq).Return(tc.authNRes, tc.authNErr)
|
||||
authCall := auth.On("Authenticate", mock.Anything, strings.TrimPrefix(tc.authKey, apiutil.BearerPrefix)).Return(tc.authNRes1, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
ClientType: tc.clientType,
|
||||
ClientId: tc.clientID,
|
||||
Type: uint32(connections.Subscribe),
|
||||
ChannelId: tc.chanID,
|
||||
DomainId: tc.domainID,
|
||||
}).Return(tc.authZRes, tc.authZErr)
|
||||
repoCall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
|
||||
err := svc.Subscribe(context.Background(), sessionID, tc.authKey, tc.domainID, tc.chanID, tc.subtopic, tc.topicType, c)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
repoCall.Unset()
|
||||
clientsCall.Unset()
|
||||
authCall.Unset()
|
||||
channelsCall.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
subConfig := messaging.SubscriberConfig{
|
||||
ID: sessionID,
|
||||
Topic: "m." + tc.domainID + ".c." + tc.chanID + "." + subTopic,
|
||||
ClientID: tc.clientID,
|
||||
Handler: c,
|
||||
}
|
||||
tc.clientType = policies.ClientType
|
||||
if strings.HasPrefix(tc.password, apiutil.BearerPrefix) {
|
||||
tc.clientType = policies.UserType
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
|
||||
authCall := auth.On("Authenticate", mock.Anything, strings.TrimPrefix(tc.password, apiutil.BearerPrefix)).Return(tc.authNRes1, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
ClientType: tc.clientType,
|
||||
ClientId: tc.clientID,
|
||||
Type: uint32(connections.Subscribe),
|
||||
ChannelId: tc.chanID,
|
||||
DomainId: tc.domainID,
|
||||
}).Return(tc.authZRes, tc.authZErr)
|
||||
repoCall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
|
||||
err := svc.Subscribe(context.Background(), sessionID, tc.username, tc.password, tc.domainID, tc.chanID, tc.subtopic, tc.topicType, c)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
repoCall.Unset()
|
||||
clientsCall.Unset()
|
||||
authCall.Unset()
|
||||
channelsCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+16
-7
@@ -17,6 +17,11 @@ import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
authzHeaderKey = "Authorization"
|
||||
authzQueryKey = "authorization"
|
||||
)
|
||||
|
||||
var errGenSessionID = errors.New("failed to generate session id")
|
||||
|
||||
func generateSessionID() (string, error) {
|
||||
@@ -56,7 +61,7 @@ func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicReso
|
||||
|
||||
go client.Start(ctx)
|
||||
|
||||
if err := svc.Subscribe(ctx, sessionID, req.authKey, req.domainID, req.channelID, req.subtopic, topicType, client); err != nil {
|
||||
if err := svc.Subscribe(ctx, sessionID, req.username, req.password, req.domainID, req.channelID, req.subtopic, topicType, client); err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
@@ -66,14 +71,17 @@ func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicReso
|
||||
}
|
||||
|
||||
func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *slog.Logger) (connReq, error) {
|
||||
authKey := r.Header.Get("Authorization")
|
||||
if authKey == "" {
|
||||
authKeys := r.URL.Query()["authorization"]
|
||||
if len(authKeys) == 0 {
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
switch {
|
||||
case r.URL.Query().Get(authzQueryKey) != "":
|
||||
password = r.URL.Query().Get(authzQueryKey)
|
||||
case r.Header.Get(authzHeaderKey) != "":
|
||||
password = r.Header.Get(authzHeaderKey)
|
||||
default:
|
||||
logger.Debug("Missing authorization key.")
|
||||
return connReq{}, errUnauthorizedAccess
|
||||
}
|
||||
authKey = authKeys[0]
|
||||
}
|
||||
|
||||
domain := chi.URLParam(r, "domain")
|
||||
@@ -85,7 +93,8 @@ func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *sl
|
||||
}
|
||||
|
||||
req := connReq{
|
||||
authKey: authKey,
|
||||
username: username,
|
||||
password: password,
|
||||
channelID: channelID,
|
||||
domainID: domainID,
|
||||
}
|
||||
|
||||
+2
-1
@@ -4,7 +4,8 @@
|
||||
package api
|
||||
|
||||
type connReq struct {
|
||||
authKey string
|
||||
username string
|
||||
password string
|
||||
channelID string
|
||||
domainID string
|
||||
subtopic string
|
||||
|
||||
+69
-24
@@ -5,6 +5,7 @@ package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
@@ -43,6 +44,8 @@ var (
|
||||
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
|
||||
errFailedPublish = errors.New("failed to publish")
|
||||
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
|
||||
errInvalidAuthFormat = errors.New("invalid basic auth format")
|
||||
errInvalidClientType = errors.New("invalid client type")
|
||||
)
|
||||
|
||||
// Event implements events.Event interface.
|
||||
@@ -89,7 +92,7 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
|
||||
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
|
||||
}
|
||||
|
||||
clientID, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Publish, topicType)
|
||||
clientID, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Publish, topicType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -117,7 +120,7 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Subscribe, topicType); err != nil {
|
||||
if _, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Subscribe, topicType); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -188,38 +191,41 @@ func (h *handler) Disconnect(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
|
||||
var clientID, clientType string
|
||||
func (h *handler) authAccess(ctx context.Context, username, password, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
|
||||
var token, clientType string
|
||||
var err error
|
||||
switch {
|
||||
case strings.HasPrefix(token, apiutil.BearerPrefix):
|
||||
token := strings.TrimPrefix(token, apiutil.BearerPrefix)
|
||||
authnSession, err := h.authn.Authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
case strings.HasPrefix(password, apiutil.BearerPrefix):
|
||||
token = strings.TrimPrefix(password, apiutil.BearerPrefix)
|
||||
clientType = policies.UserType
|
||||
clientID = authnSession.UserID
|
||||
default:
|
||||
secret := strings.TrimPrefix(token, apiutil.ClientPrefix)
|
||||
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
|
||||
if err != nil {
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
if !authnRes.Authenticated {
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
case username != "" && password != "":
|
||||
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
|
||||
clientType = policies.ClientType
|
||||
clientID = authnRes.GetId()
|
||||
case strings.HasPrefix(password, apiutil.BasicAuthPrefix):
|
||||
username, password, err := decodeAuth(strings.TrimPrefix(password, apiutil.BasicAuthPrefix))
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
|
||||
clientType = policies.ClientType
|
||||
default:
|
||||
token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(password, apiutil.ClientPrefix))
|
||||
clientType = policies.ClientType
|
||||
}
|
||||
|
||||
id, err := h.authenticate(ctx, clientType, token)
|
||||
if err != nil {
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
|
||||
}
|
||||
|
||||
// Health check topics do not require channel authorization.
|
||||
if topicType == messaging.HealthType {
|
||||
return clientID, nil
|
||||
return id, nil
|
||||
}
|
||||
|
||||
ar := &grpcChannelsV1.AuthzReq{
|
||||
Type: uint32(msgType),
|
||||
ClientId: clientID,
|
||||
ClientId: id,
|
||||
ClientType: clientType,
|
||||
ChannelId: chanID,
|
||||
DomainId: domainID,
|
||||
@@ -232,5 +238,44 @@ func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string
|
||||
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (h *handler) authenticate(ctx context.Context, authType, token string) (string, error) {
|
||||
switch authType {
|
||||
case policies.UserType:
|
||||
authnSession, err := h.authn.Authenticate(ctx, token)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return authnSession.UserID, nil
|
||||
case policies.ClientType:
|
||||
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !authnRes.Authenticated {
|
||||
return "", svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
return authnRes.GetId(), nil
|
||||
default:
|
||||
return "", errInvalidClientType
|
||||
}
|
||||
}
|
||||
|
||||
// decodeAuth decodes the base64 encoded string in the format "clientID:secret".
|
||||
func decodeAuth(s string) (string, string, error) {
|
||||
db, err := base64.URLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
parts := strings.SplitN(string(db), ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", "", errInvalidAuthFormat
|
||||
}
|
||||
clientID := parts[0]
|
||||
secret := parts[1]
|
||||
|
||||
return clientID, secret, nil
|
||||
}
|
||||
|
||||
+189
-16
@@ -5,6 +5,7 @@ package ws_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -82,16 +83,37 @@ func TestAuthPublish(t *testing.T) {
|
||||
clientKeySession := session.Session{
|
||||
Password: []byte("Client " + clientKey),
|
||||
}
|
||||
unauthorizedKeySession := session.Session{
|
||||
Password: []byte("Client " + clientKey),
|
||||
}
|
||||
invalidClientKeySession := session.Session{
|
||||
Password: []byte("Client " + invalidKey),
|
||||
}
|
||||
|
||||
tokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + validToken),
|
||||
}
|
||||
invalidTokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + invalidToken),
|
||||
}
|
||||
basicAuthSession := session.Session{
|
||||
Username: clientID,
|
||||
Password: []byte(clientKey),
|
||||
}
|
||||
invalidBasicAuthSession := session.Session{
|
||||
Username: clientID,
|
||||
Password: []byte(invalidValue),
|
||||
}
|
||||
creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
|
||||
encodedCredsSession := session.Session{
|
||||
Password: []byte(apiutil.BasicAuthPrefix + creds),
|
||||
}
|
||||
invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue)))
|
||||
invalidEncodedCredsSession := session.Session{
|
||||
Password: []byte(apiutil.BasicAuthPrefix + invalidCreds),
|
||||
}
|
||||
hcClientKeySession := session.Session{
|
||||
Password: []byte("Client " + clientKey),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
@@ -104,6 +126,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
chanID string
|
||||
domainID string
|
||||
clientID string
|
||||
authNToken string
|
||||
authNRes *grpcClientsV1.AuthnRes
|
||||
authNRes1 smqauthn.Session
|
||||
authNErr error
|
||||
@@ -122,6 +145,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
@@ -137,9 +161,10 @@ func TestAuthPublish(t *testing.T) {
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "publish with nil session",
|
||||
@@ -159,7 +184,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "publish with unauthorized client key",
|
||||
session: &clientKeySession,
|
||||
session: &unauthorizedKeySession,
|
||||
topic: &topic,
|
||||
authKey: clientKey,
|
||||
payload: &payload,
|
||||
@@ -167,6 +192,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
@@ -202,10 +228,10 @@ func TestAuthPublish(t *testing.T) {
|
||||
authNRes1: smqauthn.Session{},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "publish with unauthorized client key",
|
||||
desc: "publish with unauthorized token",
|
||||
session: &tokenSession,
|
||||
topic: &topic,
|
||||
authKey: token,
|
||||
@@ -220,9 +246,71 @@ func TestAuthPublish(t *testing.T) {
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with basic auth successfully",
|
||||
session: &basicAuthSession,
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
status: http.StatusOK,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid basic auth",
|
||||
session: &invalidBasicAuthSession,
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
authKey: invalidValue,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "publish with b64 encoded credentials",
|
||||
session: &encodedCredsSession,
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
status: http.StatusOK,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid b64 encoded credentials",
|
||||
session: &invalidEncodedCredsSession,
|
||||
topic: &topic,
|
||||
payload: &payload,
|
||||
authKey: invalidValue,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "publish with health check topic successfully",
|
||||
session: &clientKeySession,
|
||||
session: &hcClientKeySession,
|
||||
topic: &hcTopic,
|
||||
authKey: clientKey,
|
||||
payload: &payload,
|
||||
@@ -231,12 +319,13 @@ func TestAuthPublish(t *testing.T) {
|
||||
chanID: "",
|
||||
domainID: domainID,
|
||||
clientID: userID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: true},
|
||||
authNErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid health check topic",
|
||||
session: &clientKeySession,
|
||||
session: &hcClientKeySession,
|
||||
topic: &invalidHCTopic,
|
||||
authKey: clientKey,
|
||||
payload: &payload,
|
||||
@@ -256,7 +345,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
|
||||
tc.clientType = policies.UserType
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.authKey)}).Return(tc.authNRes, tc.authNErr)
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
|
||||
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
ClientType: tc.clientType,
|
||||
@@ -284,16 +373,37 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
clientKeySession := session.Session{
|
||||
Password: []byte("Client " + clientKey),
|
||||
}
|
||||
unauthorizedKeySession := session.Session{
|
||||
Password: []byte("Client " + clientKey),
|
||||
}
|
||||
invalidClientKeySession := session.Session{
|
||||
Password: []byte("Client " + invalidKey),
|
||||
}
|
||||
|
||||
tokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + validToken),
|
||||
}
|
||||
invalidTokenSession := session.Session{
|
||||
Password: []byte(apiutil.BearerPrefix + invalidToken),
|
||||
}
|
||||
basicAuthSession := session.Session{
|
||||
Username: clientID,
|
||||
Password: []byte(clientKey),
|
||||
}
|
||||
invalidBasicAuthSession := session.Session{
|
||||
Username: clientID,
|
||||
Password: []byte(invalidValue),
|
||||
}
|
||||
creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
|
||||
encodedCredsSession := session.Session{
|
||||
Password: []byte(apiutil.BasicAuthPrefix + creds),
|
||||
}
|
||||
invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue)))
|
||||
invalidEncodedCredsSession := session.Session{
|
||||
Password: []byte(apiutil.BasicAuthPrefix + invalidCreds),
|
||||
}
|
||||
hcClientKeySession := session.Session{
|
||||
Password: []byte("Client " + clientKey),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
@@ -305,6 +415,7 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
chanID string
|
||||
domainID string
|
||||
clientID string
|
||||
authNToken string
|
||||
authNRes *grpcClientsV1.AuthnRes
|
||||
authNRes1 smqauthn.Session
|
||||
authNErr error
|
||||
@@ -322,6 +433,7 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
@@ -336,9 +448,10 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "subscribe with empty topics",
|
||||
@@ -358,13 +471,14 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "subscribe with unauthorized client key",
|
||||
session: &clientKeySession,
|
||||
session: &unauthorizedKeySession,
|
||||
topics: &topics,
|
||||
authKey: clientKey,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
@@ -398,10 +512,10 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
authNRes1: smqauthn.Session{},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "subscribe with unauthorized client key",
|
||||
desc: "subscribe with unauthorized token",
|
||||
session: &tokenSession,
|
||||
topics: &topics,
|
||||
authKey: token,
|
||||
@@ -415,9 +529,67 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with basic auth successfully",
|
||||
session: &basicAuthSession,
|
||||
topics: &topics,
|
||||
status: http.StatusOK,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with invalid basic auth",
|
||||
session: &invalidBasicAuthSession,
|
||||
topics: &topics,
|
||||
authKey: invalidValue,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "publish with b64 encoded credentials",
|
||||
session: &encodedCredsSession,
|
||||
topics: &topics,
|
||||
status: http.StatusOK,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
authNErr: nil,
|
||||
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish with invalid b64 encoded credentials",
|
||||
session: &invalidEncodedCredsSession,
|
||||
topics: &topics,
|
||||
authKey: invalidValue,
|
||||
clientType: policies.ClientType,
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "subscribe with health check topic successfully",
|
||||
session: &clientKeySession,
|
||||
session: &hcClientKeySession,
|
||||
topics: &[]string{hcTopic},
|
||||
authKey: clientKey,
|
||||
status: http.StatusOK,
|
||||
@@ -425,12 +597,13 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
chanID: chanID,
|
||||
domainID: domainID,
|
||||
clientID: clientID,
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with invalid health check topic",
|
||||
session: &clientKeySession,
|
||||
session: &hcClientKeySession,
|
||||
topics: &[]string{invalidHCTopic},
|
||||
authKey: clientKey,
|
||||
status: http.StatusBadRequest,
|
||||
@@ -449,7 +622,7 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
|
||||
tc.clientType = policies.UserType
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.authKey)}).Return(tc.authNRes, tc.authNErr)
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
|
||||
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
ClientType: tc.clientType,
|
||||
|
||||
@@ -26,7 +26,7 @@ func NewLogging(svc ws.Service, logger *slog.Logger) ws.Service {
|
||||
|
||||
// Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, authKey, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) (err error) {
|
||||
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
@@ -45,7 +45,7 @@ func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, authKey,
|
||||
lm.logger.Info("Subscribe completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Subscribe(ctx, sessionID, authKey, domainID, chanID, subtopic, topicType, c)
|
||||
return lm.svc.Subscribe(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, c)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) (err error) {
|
||||
|
||||
@@ -32,13 +32,13 @@ func NewMetrics(svc ws.Service, counter metrics.Counter, latency metrics.Histogr
|
||||
}
|
||||
|
||||
// Subscribe instruments Subscribe method with metrics.
|
||||
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, authKey, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) error {
|
||||
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "subscribe").Add(1)
|
||||
mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.Subscribe(ctx, sessionID, authKey, domainID, chanID, subtopic, topicType, c)
|
||||
return mm.svc.Subscribe(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, c)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) error {
|
||||
|
||||
@@ -32,11 +32,11 @@ func NewTracing(tracer trace.Tracer, svc ws.Service) ws.Service {
|
||||
}
|
||||
|
||||
// Subscribe traces the "Subscribe" operation of the wrapped ws.Service.
|
||||
func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, topicType messaging.TopicType, client *ws.Client) error {
|
||||
func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, client *ws.Client) error {
|
||||
ctx, span := tm.tracer.Start(ctx, subscribeOP)
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, topicType, client)
|
||||
return tm.svc.Subscribe(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, client)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) error {
|
||||
|
||||
Reference in New Issue
Block a user