SMQ-2799 - Add support for basic auth for HTTP and WS adapters (#3049)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2025-12-01 18:51:49 +03:00
committed by GitHub
parent 9d13d5b528
commit cd281c2589
13 changed files with 747 additions and 323 deletions
+9 -4
View File
@@ -8,11 +8,16 @@ import (
"strings"
)
// BearerPrefix represents the token prefix for Bearer authentication scheme.
const BearerPrefix = "Bearer "
const (
// BearerPrefix represents the token prefix for Bearer authentication scheme.
BearerPrefix = "Bearer "
// ClientPrefix represents the key prefix for Client authentication scheme.
const ClientPrefix = "Client "
// ClientPrefix represents the key prefix for Client authentication scheme.
ClientPrefix = "Client "
// BasicAuthPrefix represents the prefix for Basic authentication scheme.
BasicAuthPrefix = "Basic "
)
// ExtractBearerToken returns value of the bearer token. If there is no bearer token - an empty value is returned.
func ExtractBearerToken(r *http.Request) string {
+71 -34
View File
@@ -5,6 +5,7 @@ package http
import (
"context"
"encoding/base64"
"fmt"
"log/slog"
"net/http"
@@ -26,19 +27,12 @@ import (
var _ session.Handler = (*handler)(nil)
type ctxKey string
const (
protocol = "http"
clientIDCtxKey ctxKey = "client_id"
clientTypeCtxKey ctxKey = "client_type"
)
const protocol = "http"
// Log message formats.
const (
logInfoPublished = "published with client_type %s client_id %s to the topic %s"
logInfoFailedAuthNToken = "failed to authenticate token for topic %s with error %s"
logInfoFailedAuthNClient = "failed to authenticate client key %s for topic %s with error %s"
publishedInfoFmt = "published with client_type %s client_id %s to the topic %s"
failedAuthnFmt = "failed to authenticate client_type %s for topic %s with error %s"
)
// Error wrappers for MQTT errors.
@@ -48,6 +42,8 @@ var (
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
errMalformedTopic = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("malformed topic"))
errMissingTopicPub = mgate.NewHTTPProxyError(http.StatusBadRequest, errors.New("failed to publish due to missing topic"))
errInvalidAuthFormat = errors.New("invalid basic auth format")
errInvalidClientType = errors.New("invalid client type")
)
// Event implements events.Event interface.
@@ -118,37 +114,39 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return errors.Wrap(errMalformedTopic, err)
}
var clientID, clientType string
var token, clientType string
pass := string(s.Password)
switch {
case strings.HasPrefix(string(s.Password), "Client"):
secret := strings.TrimPrefix(string(s.Password), apiutil.ClientPrefix)
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
if err != nil {
h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, *topic, err))
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
if !authnRes.Authenticated {
h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, *topic, svcerr.ErrAuthentication))
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
case s.Username != "" && pass != "":
token = smqauthn.AuthPack(smqauthn.BasicAuth, s.Username, pass)
clientType = policies.ClientType
clientID = authnRes.GetId()
case strings.HasPrefix(string(s.Password), apiutil.BearerPrefix):
token := strings.TrimPrefix(string(s.Password), apiutil.BearerPrefix)
authnSession, err := h.authn.Authenticate(ctx, token)
case strings.HasPrefix(pass, apiutil.BasicAuthPrefix):
username, password, err := decodeAuth(strings.TrimPrefix(pass, apiutil.BasicAuthPrefix))
if err != nil {
h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNToken, *topic, err))
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
h.logger.Warn(fmt.Sprintf(failedAuthnFmt, policies.ClientType, *topic, err))
return mgate.NewHTTPProxyError(http.StatusUnauthorized, err)
}
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
case strings.HasPrefix(pass, apiutil.ClientPrefix):
token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(pass, apiutil.ClientPrefix))
clientType = policies.ClientType
case strings.HasPrefix(pass, apiutil.BearerPrefix):
token = strings.TrimPrefix(pass, apiutil.BearerPrefix)
clientType = policies.UserType
clientID = authnSession.UserID
default:
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
id, err := h.authenticate(ctx, clientType, token)
if err != nil {
h.logger.Warn(fmt.Sprintf(failedAuthnFmt, clientType, *topic, err))
return mgate.NewHTTPProxyError(http.StatusUnauthorized, err)
}
// Health topics are not published to message broker.
if topicType == messaging.HealthType {
h.logger.Info(fmt.Sprintf(logInfoPublished, clientType, clientID, *topic))
h.logger.Info(fmt.Sprintf(publishedInfoFmt, clientType, id, *topic))
return nil
}
@@ -157,21 +155,21 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
Domain: domainID,
Channel: channelID,
Subtopic: subtopic,
Publisher: clientID,
Publisher: id,
Payload: *payload,
Created: time.Now().UnixNano(),
}
ar := &grpcChannelsV1.AuthzReq{
DomainId: domainID,
ClientId: clientID,
ClientId: id,
ClientType: clientType,
ChannelId: msg.Channel,
Type: uint32(connections.Publish),
}
res, err := h.channels.Authorize(ctx, ar)
if err != nil {
return mgate.NewHTTPProxyError(http.StatusBadRequest, err)
return mgate.NewHTTPProxyError(http.StatusUnauthorized, err)
}
if !res.GetAuthorized() {
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthorization)
@@ -181,7 +179,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
return errors.Wrap(errFailedPublishToMsgBroker, err)
}
h.logger.Info(fmt.Sprintf(logInfoPublished, clientType, clientID, *topic))
h.logger.Info(fmt.Sprintf(publishedInfoFmt, clientType, id, *topic))
return nil
}
@@ -200,3 +198,42 @@ func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error {
func (h *handler) Disconnect(ctx context.Context) error {
return nil
}
func (h *handler) authenticate(ctx context.Context, authType, token string) (string, error) {
switch authType {
case policies.UserType:
authnSession, err := h.authn.Authenticate(ctx, token)
if err != nil {
return "", err
}
return authnSession.UserID, nil
case policies.ClientType:
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token})
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
if !authnRes.Authenticated {
return "", svcerr.ErrAuthentication
}
return authnRes.GetId(), nil
default:
return "", errInvalidClientType
}
}
// decodeAuth decodes the base64 encoded string in the format "clientID:secret".
func decodeAuth(s string) (string, string, error) {
db, err := base64.URLEncoding.DecodeString(s)
if err != nil {
return "", "", err
}
parts := strings.SplitN(string(db), ":", 2)
if len(parts) != 2 {
return "", "", errInvalidAuthFormat
}
clientID := parts[0]
secret := parts[1]
return clientID, secret, nil
}
+142 -76
View File
@@ -5,6 +5,7 @@ package http_test
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"strings"
@@ -145,21 +146,34 @@ func TestPublish(t *testing.T) {
clientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
invalidTokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + invalidToken),
}
basicAuthSession := session.Session{
Username: clientID,
Password: []byte(clientKey),
}
invalidBasicAuthSession := session.Session{
Username: clientID,
Password: []byte(invalidValue),
}
creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
encodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + creds),
}
cases := []struct {
desc string
topic *string
channelID string
username string
payload *[]byte
password string
session *session.Session
status int
authNToken string
authNRes *grpcClientsV1.AuthnRes
authNRes1 smqauthn.Session
authNErr error
@@ -169,17 +183,18 @@ func TestPublish(t *testing.T) {
err error
}{
{
desc: "publish with key successfully",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
err: nil,
desc: "publish with key successfully",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
err: nil,
},
{
desc: "publish with token successfully",
@@ -207,17 +222,18 @@ func TestPublish(t *testing.T) {
err: svcerr.ErrAuthentication,
},
{
desc: "publish with key and subtopic successfully",
topic: &subtopic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
err: nil,
desc: "publish with key and subtopic successfully",
topic: &subtopic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
err: nil,
},
{
desc: "publish with empty topic",
@@ -238,14 +254,15 @@ func TestPublish(t *testing.T) {
err: errClientNotInitialized,
},
{
desc: "publish with invalid topic",
topic: &invalidTopic,
status: http.StatusBadRequest,
password: clientKey,
session: &clientKeySession,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
err: errMalformedTopic,
desc: "publish with invalid topic",
topic: &invalidTopic,
status: http.StatusBadRequest,
password: clientKey,
session: &clientKeySession,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
err: errMalformedTopic,
},
{
desc: "publish with malformwd subtopic",
@@ -269,28 +286,30 @@ func TestPublish(t *testing.T) {
err: svcerr.ErrAuthentication,
},
{
desc: "publish with client key and failed to authenticate",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
status: http.StatusUnauthorized,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
authNErr: nil,
err: svcerr.ErrAuthentication,
desc: "publish with client key and failed to authenticate",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
status: http.StatusUnauthorized,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
authNErr: nil,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with client key and failed to authenticate with error",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
status: http.StatusUnauthorized,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
authNErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
desc: "publish with client key and failed to authenticate with error",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
status: http.StatusUnauthorized,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: false},
authNErr: svcerr.ErrAuthentication,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "publish with token and failed to authenticate",
@@ -305,32 +324,78 @@ func TestPublish(t *testing.T) {
err: svcerr.ErrAuthentication,
},
{
desc: "publish with unauthorized client",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
status: http.StatusUnauthorized,
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: nil,
err: svcerr.ErrAuthorization,
desc: "publish with basic auth",
topic: &topic,
payload: &payload,
username: clientID,
password: clientKey,
session: &basicAuthSession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
},
{
desc: "publish with authorization error",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
status: http.StatusBadRequest,
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
desc: "publish with invalid basic auth",
topic: &topic,
payload: &payload,
username: clientID,
password: clientKey,
session: &invalidBasicAuthSession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "publish with encoded credentials",
topic: &topic,
payload: &payload,
username: clientID,
password: clientKey,
session: &encodedCredsSession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNRes1: smqauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
authZErr: nil,
},
{
desc: "publish with unauthorized client",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
status: http.StatusUnauthorized,
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: nil,
err: svcerr.ErrAuthorization,
},
{
desc: "publish with authorization error",
topic: &topic,
payload: &payload,
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
status: http.StatusUnauthorized,
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "publish with failed to publish",
@@ -339,6 +404,7 @@ func TestPublish(t *testing.T) {
password: clientKey,
session: &clientKeySession,
channelID: chanID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
@@ -375,7 +441,7 @@ func TestPublish(t *testing.T) {
if tc.topic != nil {
internalTopic = strings.TrimPrefix(strings.ReplaceAll(*tc.topic, "/", "."), ".m.")
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.password)}).Return(tc.authNRes, tc.authNErr)
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", ctx, mock.Anything).Return(tc.authZRes, tc.authZErr)
repoCall := publisher.On("Publish", ctx, internalTopic, mock.Anything).Return(tc.publishErr)
-4
View File
@@ -121,7 +121,6 @@ func TestSendMessage(t *testing.T) {
msg: msg,
secret: "",
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
authErr: svcerr.ErrAuthentication,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
@@ -132,7 +131,6 @@ func TestSendMessage(t *testing.T) {
msg: msg,
secret: "invalid",
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
authErr: svcerr.ErrAuthentication,
svcErr: svcerr.ErrAuthentication,
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
@@ -143,7 +141,6 @@ func TestSendMessage(t *testing.T) {
msg: msg,
secret: clientKey,
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
authErr: svcerr.ErrAuthentication,
svcErr: svcerr.ErrAuthentication,
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
@@ -176,7 +173,6 @@ func TestSendMessage(t *testing.T) {
msg: msg,
secret: clientKey,
authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""},
authErr: svcerr.ErrAuthentication,
svcErr: svcerr.ErrAuthentication,
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
+49 -23
View File
@@ -33,7 +33,7 @@ type Service interface {
// the channelID for subscription and domainID specifies the domain for authorization.
// Subtopic is optional.
// If the subscription is successful, nil is returned otherwise error is returned.
Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, topicType messaging.TopicType, client *Client) error
Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, client *Client) error
Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) error
}
@@ -57,12 +57,12 @@ func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.Cha
}
}
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, clientKey, domainID, channelID, subtopic string, topicType messaging.TopicType, c *Client) error {
if (channelID == "" && topicType != messaging.HealthType) || clientKey == "" || domainID == "" {
func (svc *adapterService) Subscribe(ctx context.Context, sessionID, username, password, domainID, channelID, subtopic string, topicType messaging.TopicType, c *Client) error {
if (channelID == "" && topicType != messaging.HealthType) || password == "" || domainID == "" {
return svcerr.ErrAuthentication
}
clientID, err := svc.authorize(ctx, clientKey, domainID, channelID, connections.Subscribe, topicType)
clientID, err := svc.authorize(ctx, username, password, domainID, channelID, connections.Subscribe, topicType)
if err != nil {
return svcerr.ErrAuthorization
}
@@ -102,38 +102,41 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID,
// authorize checks if the authKey is authorized to access the channel
// and returns the clientID or userID if it is.
func (svc *adapterService) authorize(ctx context.Context, authKey, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
var clientID, clientType string
func (svc *adapterService) authorize(ctx context.Context, username, password, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
var token, clientType string
var err error
switch {
case strings.HasPrefix(authKey, apiutil.BearerPrefix):
token := strings.TrimPrefix(authKey, apiutil.BearerPrefix)
authnSession, err := svc.authn.Authenticate(ctx, token)
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
case strings.HasPrefix(password, apiutil.BearerPrefix):
token = strings.TrimPrefix(password, apiutil.BearerPrefix)
clientType = policies.UserType
clientID = authnSession.UserID
default:
secret := strings.TrimPrefix(authKey, apiutil.ClientPrefix)
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
case username != "" && password != "":
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
case strings.HasPrefix(password, apiutil.BasicAuthPrefix):
username, password, err := decodeAuth(strings.TrimPrefix(password, apiutil.BasicAuthPrefix))
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
if !authnRes.Authenticated {
return "", svcerr.ErrAuthentication
}
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
clientID = authnRes.GetId()
default:
token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(password, apiutil.ClientPrefix))
clientType = policies.ClientType
}
id, err := svc.authenticate(ctx, clientType, token)
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
// Health check topics do not require channel authorization.
if topicType == messaging.HealthType {
return clientID, nil
return id, nil
}
authzReq := &grpcChannelsV1.AuthzReq{
ClientType: clientType,
ClientId: clientID,
ClientId: id,
Type: uint32(msgType),
ChannelId: chanID,
DomainId: domainID,
@@ -146,5 +149,28 @@ func (svc *adapterService) authorize(ctx context.Context, authKey, domainID, cha
return "", errors.Wrap(svcerr.ErrAuthorization, err)
}
return clientID, nil
return id, nil
}
func (svc *adapterService) authenticate(ctx context.Context, authType, token string) (string, error) {
switch authType {
case policies.UserType:
authnSession, err := svc.authn.Authenticate(ctx, token)
if err != nil {
return "", err
}
return authnSession.UserID, nil
case policies.ClientType:
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token})
if err != nil {
return "", err
}
if !authnRes.Authenticated {
return "", svcerr.ErrAuthentication
}
return authnRes.GetId(), nil
default:
return "", errInvalidClientType
}
}
+194 -128
View File
@@ -5,6 +5,7 @@ package ws_test
import (
"context"
"encoding/base64"
"fmt"
"log/slog"
"strings"
@@ -53,7 +54,9 @@ var (
Protocol: protocol,
Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`),
}
sessionID = "sessionID"
sessionID = "sessionID"
validEncodedCreds = base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
invalidEncodedCreds = base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", invalidID, invalidKey)))
)
func newService() (ws.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient, *authnmocks.Authentication) {
@@ -72,13 +75,15 @@ func TestSubscribe(t *testing.T) {
cases := []struct {
desc string
authKey string
username string
password string
chanID string
domainID string
subtopic string
clientType string
clientID string
topicType messaging.TopicType
authNToken string
authNRes *grpcClientsV1.AuthnRes
authNErr error
authNRes1 smqauthn.Session
@@ -88,20 +93,21 @@ func TestSubscribe(t *testing.T) {
err error
}{
{
desc: "subscribe to channel with valid clientKey, chanID, subtopic",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
desc: "subscribe to channel with valid clientKey, chanID, subtopic",
password: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe to channel with valid token, chanID, subtopic",
authKey: token,
password: token,
chanID: chanID,
domainID: domainID,
clientID: userID,
@@ -113,7 +119,7 @@ func TestSubscribe(t *testing.T) {
},
{
desc: "subscribe to channel with invalid token",
authKey: invalidToken,
password: invalidToken,
chanID: chanID,
domainID: domainID,
subtopic: subTopic,
@@ -123,45 +129,48 @@ func TestSubscribe(t *testing.T) {
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe again to channel with valid clientKey, chanID, subtopic",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
desc: "subscribe again to channel with valid clientKey, chanID, subtopic",
password: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe to channel with subscribe set to fail",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
subErr: ws.ErrFailedSubscription,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: ws.ErrFailedSubscription,
desc: "subscribe to channel with subscribe set to fail",
password: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
subErr: ws.ErrFailedSubscription,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: ws.ErrFailedSubscription,
},
{
desc: "subscribe to channel with invalid clientKey",
authKey: invalidKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
authNErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthorization,
desc: "subscribe to channel with invalid clientKey",
password: invalidKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
authNErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with empty channel",
authKey: clientKey,
password: clientKey,
chanID: "",
domainID: domainID,
clientID: clientID,
@@ -171,7 +180,7 @@ func TestSubscribe(t *testing.T) {
},
{
desc: "subscribe to channel with empty clientKey",
authKey: "",
password: "",
chanID: chanID,
domainID: domainID,
clientID: clientID,
@@ -181,7 +190,7 @@ func TestSubscribe(t *testing.T) {
},
{
desc: "subscribe to channel with empty clientKey and empty channel",
authKey: "",
password: "",
chanID: "",
domainID: domainID,
clientID: clientID,
@@ -190,66 +199,125 @@ func TestSubscribe(t *testing.T) {
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe to channel with invalid channel",
authKey: clientKey,
chanID: invalidID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
desc: "subscribe to channel with invalid channel",
password: clientKey,
chanID: invalidID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
authZErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with failed authentication",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
err: svcerr.ErrAuthorization,
desc: "subscribe to channel with failed authentication",
password: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with failed authorization",
authKey: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
err: svcerr.ErrAuthorization,
desc: "subscribe to channel with failed authorization",
password: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic",
authKey: "Client " + clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic",
password: "Client " + clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe to health check topic with empty channel and valid clientKey",
authKey: clientKey,
chanID: "",
domainID: domainID,
clientID: clientID,
topicType: messaging.HealthType,
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
err: nil,
desc: "subscribe to channel with basic auth",
username: clientID,
password: clientKey,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe to channel with basic auth and invalid credentials",
username: invalidID,
password: invalidKey,
chanID: chanID,
domainID: domainID,
clientID: invalidID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, invalidID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
authNErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to channel with b64 encoded credentials",
password: apiutil.BasicAuthPrefix + validEncodedCreds,
chanID: chanID,
domainID: domainID,
clientID: clientID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe to channel with b64 encoded credentials and invalid credentials",
password: apiutil.BasicAuthPrefix + invalidEncodedCreds,
chanID: chanID,
domainID: domainID,
clientID: invalidID,
subtopic: subTopic,
topicType: messaging.MessageType,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, invalidID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
authNErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthorization,
},
{
desc: "subscribe to health check topic with empty channel and valid clientKey",
password: clientKey,
chanID: "",
domainID: domainID,
clientID: clientID,
topicType: messaging.HealthType,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
err: nil,
},
{
desc: "subscribe to health check topic with empty channel and valid token",
authKey: token,
password: token,
chanID: "",
domainID: domainID,
clientID: userID,
@@ -259,7 +327,7 @@ func TestSubscribe(t *testing.T) {
},
{
desc: "subscribe to health check topic with empty domain and valid clientKey",
authKey: clientKey,
password: clientKey,
chanID: "",
domainID: "",
clientID: clientID,
@@ -270,35 +338,33 @@ func TestSubscribe(t *testing.T) {
}
for _, tc := range cases {
subConfig := messaging.SubscriberConfig{
ID: sessionID,
Topic: "m." + tc.domainID + ".c." + tc.chanID + "." + subTopic,
ClientID: tc.clientID,
Handler: c,
}
authReq := &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.authKey)}
tc.clientType = policies.ClientType
if strings.HasPrefix(tc.authKey, "Client") {
authReq.Token = smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, strings.TrimPrefix(tc.authKey, "Client "))
}
if strings.HasPrefix(tc.authKey, apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", mock.Anything, authReq).Return(tc.authNRes, tc.authNErr)
authCall := auth.On("Authenticate", mock.Anything, strings.TrimPrefix(tc.authKey, apiutil.BearerPrefix)).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
ClientId: tc.clientID,
Type: uint32(connections.Subscribe),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
repoCall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
err := svc.Subscribe(context.Background(), sessionID, tc.authKey, tc.domainID, tc.chanID, tc.subtopic, tc.topicType, c)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
clientsCall.Unset()
authCall.Unset()
channelsCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
subConfig := messaging.SubscriberConfig{
ID: sessionID,
Topic: "m." + tc.domainID + ".c." + tc.chanID + "." + subTopic,
ClientID: tc.clientID,
Handler: c,
}
tc.clientType = policies.ClientType
if strings.HasPrefix(tc.password, apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
authCall := auth.On("Authenticate", mock.Anything, strings.TrimPrefix(tc.password, apiutil.BearerPrefix)).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
ClientId: tc.clientID,
Type: uint32(connections.Subscribe),
ChannelId: tc.chanID,
DomainId: tc.domainID,
}).Return(tc.authZRes, tc.authZErr)
repoCall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr)
err := svc.Subscribe(context.Background(), sessionID, tc.username, tc.password, tc.domainID, tc.chanID, tc.subtopic, tc.topicType, c)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
clientsCall.Unset()
authCall.Unset()
channelsCall.Unset()
})
}
}
+16 -7
View File
@@ -17,6 +17,11 @@ import (
"github.com/go-chi/chi/v5"
)
const (
authzHeaderKey = "Authorization"
authzQueryKey = "authorization"
)
var errGenSessionID = errors.New("failed to generate session id")
func generateSessionID() (string, error) {
@@ -56,7 +61,7 @@ func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicReso
go client.Start(ctx)
if err := svc.Subscribe(ctx, sessionID, req.authKey, req.domainID, req.channelID, req.subtopic, topicType, client); err != nil {
if err := svc.Subscribe(ctx, sessionID, req.username, req.password, req.domainID, req.channelID, req.subtopic, topicType, client); err != nil {
conn.Close()
return
}
@@ -66,14 +71,17 @@ func handshake(ctx context.Context, svc ws.Service, resolver messaging.TopicReso
}
func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *slog.Logger) (connReq, error) {
authKey := r.Header.Get("Authorization")
if authKey == "" {
authKeys := r.URL.Query()["authorization"]
if len(authKeys) == 0 {
username, password, ok := r.BasicAuth()
if !ok {
switch {
case r.URL.Query().Get(authzQueryKey) != "":
password = r.URL.Query().Get(authzQueryKey)
case r.Header.Get(authzHeaderKey) != "":
password = r.Header.Get(authzHeaderKey)
default:
logger.Debug("Missing authorization key.")
return connReq{}, errUnauthorizedAccess
}
authKey = authKeys[0]
}
domain := chi.URLParam(r, "domain")
@@ -85,7 +93,8 @@ func decodeRequest(r *http.Request, resolver messaging.TopicResolver, logger *sl
}
req := connReq{
authKey: authKey,
username: username,
password: password,
channelID: channelID,
domainID: domainID,
}
+2 -1
View File
@@ -4,7 +4,8 @@
package api
type connReq struct {
authKey string
username string
password string
channelID string
domainID string
subtopic string
+69 -24
View File
@@ -5,6 +5,7 @@ package ws
import (
"context"
"encoding/base64"
"fmt"
"log/slog"
"net/http"
@@ -43,6 +44,8 @@ var (
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
errFailedPublish = errors.New("failed to publish")
errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker")
errInvalidAuthFormat = errors.New("invalid basic auth format")
errInvalidClientType = errors.New("invalid client type")
)
// Event implements events.Event interface.
@@ -89,7 +92,7 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err))
}
clientID, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Publish, topicType)
clientID, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Publish, topicType)
if err != nil {
return err
}
@@ -117,7 +120,7 @@ func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error {
if err != nil {
return err
}
if _, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Subscribe, topicType); err != nil {
if _, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Subscribe, topicType); err != nil {
return err
}
}
@@ -188,38 +191,41 @@ func (h *handler) Disconnect(ctx context.Context) error {
return nil
}
func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
var clientID, clientType string
func (h *handler) authAccess(ctx context.Context, username, password, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) {
var token, clientType string
var err error
switch {
case strings.HasPrefix(token, apiutil.BearerPrefix):
token := strings.TrimPrefix(token, apiutil.BearerPrefix)
authnSession, err := h.authn.Authenticate(ctx, token)
if err != nil {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
case strings.HasPrefix(password, apiutil.BearerPrefix):
token = strings.TrimPrefix(password, apiutil.BearerPrefix)
clientType = policies.UserType
clientID = authnSession.UserID
default:
secret := strings.TrimPrefix(token, apiutil.ClientPrefix)
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
if err != nil {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
if !authnRes.Authenticated {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
case username != "" && password != "":
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
clientID = authnRes.GetId()
case strings.HasPrefix(password, apiutil.BasicAuthPrefix):
username, password, err := decodeAuth(strings.TrimPrefix(password, apiutil.BasicAuthPrefix))
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password)
clientType = policies.ClientType
default:
token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(password, apiutil.ClientPrefix))
clientType = policies.ClientType
}
id, err := h.authenticate(ctx, clientType, token)
if err != nil {
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}
// Health check topics do not require channel authorization.
if topicType == messaging.HealthType {
return clientID, nil
return id, nil
}
ar := &grpcChannelsV1.AuthzReq{
Type: uint32(msgType),
ClientId: clientID,
ClientId: id,
ClientType: clientType,
ChannelId: chanID,
DomainId: domainID,
@@ -232,5 +238,44 @@ func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string
return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
}
return clientID, nil
return id, nil
}
func (h *handler) authenticate(ctx context.Context, authType, token string) (string, error) {
switch authType {
case policies.UserType:
authnSession, err := h.authn.Authenticate(ctx, token)
if err != nil {
return "", err
}
return authnSession.UserID, nil
case policies.ClientType:
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token})
if err != nil {
return "", err
}
if !authnRes.Authenticated {
return "", svcerr.ErrAuthentication
}
return authnRes.GetId(), nil
default:
return "", errInvalidClientType
}
}
// decodeAuth decodes the base64 encoded string in the format "clientID:secret".
func decodeAuth(s string) (string, string, error) {
db, err := base64.URLEncoding.DecodeString(s)
if err != nil {
return "", "", err
}
parts := strings.SplitN(string(db), ":", 2)
if len(parts) != 2 {
return "", "", errInvalidAuthFormat
}
clientID := parts[0]
secret := parts[1]
return clientID, secret, nil
}
+189 -16
View File
@@ -5,6 +5,7 @@ package ws_test
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"strings"
@@ -82,16 +83,37 @@ func TestAuthPublish(t *testing.T) {
clientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
unauthorizedKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
invalidClientKeySession := session.Session{
Password: []byte("Client " + invalidKey),
}
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
invalidTokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + invalidToken),
}
basicAuthSession := session.Session{
Username: clientID,
Password: []byte(clientKey),
}
invalidBasicAuthSession := session.Session{
Username: clientID,
Password: []byte(invalidValue),
}
creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
encodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + creds),
}
invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue)))
invalidEncodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + invalidCreds),
}
hcClientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
tests := []struct {
desc string
@@ -104,6 +126,7 @@ func TestAuthPublish(t *testing.T) {
chanID string
domainID string
clientID string
authNToken string
authNRes *grpcClientsV1.AuthnRes
authNRes1 smqauthn.Session
authNErr error
@@ -122,6 +145,7 @@ func TestAuthPublish(t *testing.T) {
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
@@ -137,9 +161,10 @@ func TestAuthPublish(t *testing.T) {
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "publish with nil session",
@@ -159,7 +184,7 @@ func TestAuthPublish(t *testing.T) {
},
{
desc: "publish with unauthorized client key",
session: &clientKeySession,
session: &unauthorizedKeySession,
topic: &topic,
authKey: clientKey,
payload: &payload,
@@ -167,6 +192,7 @@ func TestAuthPublish(t *testing.T) {
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
@@ -202,10 +228,10 @@ func TestAuthPublish(t *testing.T) {
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "publish with unauthorized client key",
desc: "publish with unauthorized token",
session: &tokenSession,
topic: &topic,
authKey: token,
@@ -220,9 +246,71 @@ func TestAuthPublish(t *testing.T) {
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "publish with basic auth successfully",
session: &basicAuthSession,
topic: &topic,
payload: &payload,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid basic auth",
session: &invalidBasicAuthSession,
topic: &topic,
payload: &payload,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "publish with b64 encoded credentials",
session: &encodedCredsSession,
topic: &topic,
payload: &payload,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid b64 encoded credentials",
session: &invalidEncodedCredsSession,
topic: &topic,
payload: &payload,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "publish with health check topic successfully",
session: &clientKeySession,
session: &hcClientKeySession,
topic: &hcTopic,
authKey: clientKey,
payload: &payload,
@@ -231,12 +319,13 @@ func TestAuthPublish(t *testing.T) {
chanID: "",
domainID: domainID,
clientID: userID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: true},
authNErr: nil,
},
{
desc: "publish with invalid health check topic",
session: &clientKeySession,
session: &hcClientKeySession,
topic: &invalidHCTopic,
authKey: clientKey,
payload: &payload,
@@ -256,7 +345,7 @@ func TestAuthPublish(t *testing.T) {
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.authKey)}).Return(tc.authNRes, tc.authNErr)
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
@@ -284,16 +373,37 @@ func TestAuthSubscribe(t *testing.T) {
clientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
unauthorizedKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
invalidClientKeySession := session.Session{
Password: []byte("Client " + invalidKey),
}
tokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + validToken),
}
invalidTokenSession := session.Session{
Password: []byte(apiutil.BearerPrefix + invalidToken),
}
basicAuthSession := session.Session{
Username: clientID,
Password: []byte(clientKey),
}
invalidBasicAuthSession := session.Session{
Username: clientID,
Password: []byte(invalidValue),
}
creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey)))
encodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + creds),
}
invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue)))
invalidEncodedCredsSession := session.Session{
Password: []byte(apiutil.BasicAuthPrefix + invalidCreds),
}
hcClientKeySession := session.Session{
Password: []byte("Client " + clientKey),
}
tests := []struct {
desc string
@@ -305,6 +415,7 @@ func TestAuthSubscribe(t *testing.T) {
chanID string
domainID string
clientID string
authNToken string
authNRes *grpcClientsV1.AuthnRes
authNRes1 smqauthn.Session
authNErr error
@@ -322,6 +433,7 @@ func TestAuthSubscribe(t *testing.T) {
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
@@ -336,9 +448,10 @@ func TestAuthSubscribe(t *testing.T) {
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "subscribe with empty topics",
@@ -358,13 +471,14 @@ func TestAuthSubscribe(t *testing.T) {
},
{
desc: "subscribe with unauthorized client key",
session: &clientKeySession,
session: &unauthorizedKeySession,
topics: &topics,
authKey: clientKey,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: false},
@@ -398,10 +512,10 @@ func TestAuthSubscribe(t *testing.T) {
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "subscribe with unauthorized client key",
desc: "subscribe with unauthorized token",
session: &tokenSession,
topics: &topics,
authKey: token,
@@ -415,9 +529,67 @@ func TestAuthSubscribe(t *testing.T) {
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with basic auth successfully",
session: &basicAuthSession,
topics: &topics,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "subscribe with invalid basic auth",
session: &invalidBasicAuthSession,
topics: &topics,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "publish with b64 encoded credentials",
session: &encodedCredsSession,
topics: &topics,
status: http.StatusOK,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
authNErr: nil,
authZRes: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "publish with invalid b64 encoded credentials",
session: &invalidEncodedCredsSession,
topics: &topics,
authKey: invalidValue,
clientType: policies.ClientType,
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "subscribe with health check topic successfully",
session: &clientKeySession,
session: &hcClientKeySession,
topics: &[]string{hcTopic},
authKey: clientKey,
status: http.StatusOK,
@@ -425,12 +597,13 @@ func TestAuthSubscribe(t *testing.T) {
chanID: chanID,
domainID: domainID,
clientID: clientID,
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey),
authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true},
err: nil,
},
{
desc: "subscribe with invalid health check topic",
session: &clientKeySession,
session: &hcClientKeySession,
topics: &[]string{invalidHCTopic},
authKey: clientKey,
status: http.StatusBadRequest,
@@ -449,7 +622,7 @@ func TestAuthSubscribe(t *testing.T) {
if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) {
tc.clientType = policies.UserType
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.authKey)}).Return(tc.authNRes, tc.authNErr)
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
ClientType: tc.clientType,
+2 -2
View File
@@ -26,7 +26,7 @@ func NewLogging(svc ws.Service, logger *slog.Logger) ws.Service {
// Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, authKey, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) (err error) {
func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
@@ -45,7 +45,7 @@ func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, authKey,
lm.logger.Info("Subscribe completed successfully", args...)
}(time.Now())
return lm.svc.Subscribe(ctx, sessionID, authKey, domainID, chanID, subtopic, topicType, c)
return lm.svc.Subscribe(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, c)
}
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) (err error) {
+2 -2
View File
@@ -32,13 +32,13 @@ func NewMetrics(svc ws.Service, counter metrics.Counter, latency metrics.Histogr
}
// Subscribe instruments Subscribe method with metrics.
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, authKey, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) error {
func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *ws.Client) error {
defer func(begin time.Time) {
mm.counter.With("method", "subscribe").Add(1)
mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Subscribe(ctx, sessionID, authKey, domainID, chanID, subtopic, topicType, c)
return mm.svc.Subscribe(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, c)
}
func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) error {
+2 -2
View File
@@ -32,11 +32,11 @@ func NewTracing(tracer trace.Tracer, svc ws.Service) ws.Service {
}
// Subscribe traces the "Subscribe" operation of the wrapped ws.Service.
func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, clientKey, domainID, chanID, subtopic string, topicType messaging.TopicType, client *ws.Client) error {
func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, client *ws.Client) error {
ctx, span := tm.tracer.Start(ctx, subscribeOP)
defer span.End()
return tm.svc.Subscribe(ctx, sessionID, clientKey, domainID, chanID, subtopic, topicType, client)
return tm.svc.Subscribe(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, client)
}
func (tm *tracingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) error {