SMQ-2866 - MQTT auth fails with identical secrets across different domains (#3030)
Continuous Delivery / Build and Push (push) Has been cancelled
Check the consistency of generated files / check-generated-files (push) Has been cancelled
Check License Header / check-license (push) Has been cancelled
Deploy GitHub Pages / swagger-ui (push) Has been cancelled

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>
This commit is contained in:
Dušan Borovčanin
2025-07-31 14:55:37 +02:00
committed by GitHub
parent 71ea84c3ec
commit 9b77130f6e
25 changed files with 187 additions and 77 deletions
+6 -15
View File
@@ -27,8 +27,7 @@ const (
type AuthnReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"`
ClientSecret string `protobuf:"bytes,2,opt,name=client_secret,json=clientSecret,proto3" json:"client_secret,omitempty"`
Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -63,16 +62,9 @@ func (*AuthnReq) Descriptor() ([]byte, []int) {
return file_clients_v1_clients_proto_rawDescGZIP(), []int{0}
}
func (x *AuthnReq) GetClientId() string {
func (x *AuthnReq) GetToken() string {
if x != nil {
return x.ClientId
}
return ""
}
func (x *AuthnReq) GetClientSecret() string {
if x != nil {
return x.ClientSecret
return x.Token
}
return ""
}
@@ -294,10 +286,9 @@ var File_clients_v1_clients_proto protoreflect.FileDescriptor
const file_clients_v1_clients_proto_rawDesc = "" +
"\n" +
"\x18clients/v1/clients.proto\x12\n" +
"clients.v1\x1a\x16common/v1/common.proto\"L\n" +
"\bAuthnReq\x12\x1b\n" +
"\tclient_id\x18\x01 \x01(\tR\bclientId\x12#\n" +
"\rclient_secret\x18\x02 \x01(\tR\fclientSecret\"@\n" +
"clients.v1\x1a\x16common/v1/common.proto\" \n" +
"\bAuthnReq\x12\x14\n" +
"\x05token\x18\x01 \x01(\tR\x05token\"@\n" +
"\bAuthnRes\x12$\n" +
"\rauthenticated\x18\x01 \x01(\bR\rauthenticated\x12\x0e\n" +
"\x02id\x18\x02 \x01(\tR\x02id\"<\n" +
+2 -4
View File
@@ -111,8 +111,7 @@ func (client grpcClient) Authenticate(ctx context.Context, req *grpcClientsV1.Au
defer cancel()
res, err := client.authenticate(ctx, authenticateReq{
ClientID: req.GetClientId(),
ClientSecret: req.GetClientSecret(),
Token: req.GetToken(),
})
if err != nil {
return &grpcClientsV1.AuthnRes{}, decodeError(err)
@@ -125,8 +124,7 @@ func (client grpcClient) Authenticate(ctx context.Context, req *grpcClientsV1.Au
func encodeAuthenticateRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(authenticateReq)
return &grpcClientsV1.AuthnReq{
ClientId: req.ClientID,
ClientSecret: req.ClientSecret,
Token: req.Token,
}, nil
}
+1 -1
View File
@@ -14,7 +14,7 @@ import (
func authenticateEndpoint(svc pClients.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authenticateReq)
id, err := svc.Authenticate(ctx, req.ClientSecret)
id, err := svc.Authenticate(ctx, req.Token)
if err != nil {
return authenticateRes{}, err
}
+1 -1
View File
@@ -97,7 +97,7 @@ func TestAuthenticate(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Authenticate", mock.Anything, tc.clientSecret).Return(tc.clientID, tc.svcErr)
res, err := client.Authenticate(context.Background(), &grpcClientsV1.AuthnReq{ClientSecret: tc.clientSecret})
res, err := client.Authenticate(context.Background(), &grpcClientsV1.AuthnReq{Token: tc.clientSecret})
assert.True(t, errors.Contains(err, tc.err))
assert.Equal(t, tc.resp, res)
svcCall.Unset()
+1 -2
View File
@@ -4,8 +4,7 @@
package grpc
type authenticateReq struct {
ClientID string
ClientSecret string
Token string
}
type retrieveEntitiesReq struct {
+1 -2
View File
@@ -84,8 +84,7 @@ func (s *grpcServer) Authenticate(ctx context.Context, req *grpcClientsV1.AuthnR
func decodeAuthorizeRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*grpcClientsV1.AuthnReq)
return authenticateReq{
ClientID: req.GetClientId(),
ClientSecret: req.GetClientSecret(),
Token: req.GetToken(),
}, nil
}
+4 -4
View File
@@ -38,8 +38,8 @@ func (tc *clientCache) Save(ctx context.Context, clientKey, clientID string) err
if clientKey == "" || clientID == "" {
return errors.Wrap(repoerr.ErrCreateEntity, errors.New("client key or client id is empty"))
}
tkey := fmt.Sprintf("%s:%s", keyPrefix, clientKey)
if err := tc.client.Set(ctx, tkey, clientID, tc.keyDuration).Err(); err != nil {
ckey := fmt.Sprintf("%s:%s", keyPrefix, clientKey)
if err := tc.client.Set(ctx, ckey, clientID, tc.keyDuration).Err(); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
@@ -56,8 +56,8 @@ func (tc *clientCache) ID(ctx context.Context, clientKey string) (string, error)
return "", repoerr.ErrNotFound
}
tkey := fmt.Sprintf("%s:%s", keyPrefix, clientKey)
clientID, err := tc.client.Get(ctx, tkey).Result()
ckey := fmt.Sprintf("%s:%s", keyPrefix, clientKey)
clientID, err := tc.client.Get(ctx, ckey).Result()
if err != nil {
return "", errors.Wrap(repoerr.ErrNotFound, err)
}
+5 -3
View File
@@ -67,8 +67,10 @@ type Repository interface {
// operation failure.
Save(ctx context.Context, client ...Client) ([]Client, error)
// RetrieveBySecret retrieves a client based on the secret (key).
RetrieveBySecret(ctx context.Context, key string) (Client, error)
// RetrieveBySecret retrieves a client based on the secret (key) and domainID.
// Domain ID is required because the key is not globally unique,
// but unique on the level of Domain.
RetrieveBySecret(ctx context.Context, key, id string, prefix authn.AuthPrefix) (Client, error)
AddConnections(ctx context.Context, conns []Connection) error
@@ -95,7 +97,7 @@ type Repository interface {
roles.Repository
}
// Service specifies an API that must be fullfiled by the domain service
// Service specifies an API that must be fulfilled by the domain service
// implementation, and all of its decorators (e.g. logging & metrics).
type Service interface {
// CreateClients creates new client. In case of the failed registration, a
+25 -12
View File
@@ -11,6 +11,7 @@ import (
"context"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/roles"
mock "github.com/stretchr/testify/mock"
)
@@ -1257,8 +1258,8 @@ func (_c *Repository_RetrieveByIds_Call) RunAndReturn(run func(ctx context.Conte
}
// RetrieveBySecret provides a mock function for the type Repository
func (_mock *Repository) RetrieveBySecret(ctx context.Context, key string) (clients.Client, error) {
ret := _mock.Called(ctx, key)
func (_mock *Repository) RetrieveBySecret(ctx context.Context, key string, id string, prefix authn.AuthPrefix) (clients.Client, error) {
ret := _mock.Called(ctx, key, id, prefix)
if len(ret) == 0 {
panic("no return value specified for RetrieveBySecret")
@@ -1266,16 +1267,16 @@ func (_mock *Repository) RetrieveBySecret(ctx context.Context, key string) (clie
var r0 clients.Client
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (clients.Client, error)); ok {
return returnFunc(ctx, key)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, authn.AuthPrefix) (clients.Client, error)); ok {
return returnFunc(ctx, key, id, prefix)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string) clients.Client); ok {
r0 = returnFunc(ctx, key)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, authn.AuthPrefix) clients.Client); ok {
r0 = returnFunc(ctx, key, id, prefix)
} else {
r0 = ret.Get(0).(clients.Client)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = returnFunc(ctx, key)
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, authn.AuthPrefix) error); ok {
r1 = returnFunc(ctx, key, id, prefix)
} else {
r1 = ret.Error(1)
}
@@ -1290,11 +1291,13 @@ type Repository_RetrieveBySecret_Call struct {
// RetrieveBySecret is a helper method to define mock.On call
// - ctx context.Context
// - key string
func (_e *Repository_Expecter) RetrieveBySecret(ctx interface{}, key interface{}) *Repository_RetrieveBySecret_Call {
return &Repository_RetrieveBySecret_Call{Call: _e.mock.On("RetrieveBySecret", ctx, key)}
// - id string
// - prefix authn.AuthPrefix
func (_e *Repository_Expecter) RetrieveBySecret(ctx interface{}, key interface{}, id interface{}, prefix interface{}) *Repository_RetrieveBySecret_Call {
return &Repository_RetrieveBySecret_Call{Call: _e.mock.On("RetrieveBySecret", ctx, key, id, prefix)}
}
func (_c *Repository_RetrieveBySecret_Call) Run(run func(ctx context.Context, key string)) *Repository_RetrieveBySecret_Call {
func (_c *Repository_RetrieveBySecret_Call) Run(run func(ctx context.Context, key string, id string, prefix authn.AuthPrefix)) *Repository_RetrieveBySecret_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
@@ -1304,9 +1307,19 @@ func (_c *Repository_RetrieveBySecret_Call) Run(run func(ctx context.Context, ke
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
var arg3 authn.AuthPrefix
if args[3] != nil {
arg3 = args[3].(authn.AuthPrefix)
}
run(
arg0,
arg1,
arg2,
arg3,
)
})
return _c
@@ -1317,7 +1330,7 @@ func (_c *Repository_RetrieveBySecret_Call) Return(client clients.Client, err er
return _c
}
func (_c *Repository_RetrieveBySecret_Call) RunAndReturn(run func(ctx context.Context, key string) (clients.Client, error)) *Repository_RetrieveBySecret_Call {
func (_c *Repository_RetrieveBySecret_Call) RunAndReturn(run func(ctx context.Context, key string, id string, prefix authn.AuthPrefix) (clients.Client, error)) *Repository_RetrieveBySecret_Call {
_c.Call.Return(run)
return _c
}
+12 -1
View File
@@ -14,6 +14,7 @@ import (
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
@@ -86,13 +87,23 @@ func (repo *clientRepo) Save(ctx context.Context, cls ...clients.Client) ([]clie
return reClients, nil
}
func (repo *clientRepo) RetrieveBySecret(ctx context.Context, key string) (clients.Client, error) {
func (repo *clientRepo) RetrieveBySecret(ctx context.Context, key, id string, prefix authn.AuthPrefix) (clients.Client, error) {
q := fmt.Sprintf(`SELECT id, name, tags, COALESCE(domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, identity, secret, metadata, created_at, updated_at, updated_by, status
FROM clients
WHERE secret = :secret AND status = %d`, clients.EnabledStatus)
switch prefix {
case authn.DomainAuth:
q += " AND domain_id = :domain_id"
case authn.BasicAuth:
q += " AND id = :id"
default:
return clients.Client{}, repoerr.ErrNotFound
}
dbc := DBClient{
Secret: key,
Domain: id,
ID: id,
}
rows, err := repo.DB.NamedQueryContext(ctx, q, dbc)
+39 -3
View File
@@ -15,6 +15,7 @@ import (
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/clients/postgres"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
@@ -358,6 +359,7 @@ func TestClientsRetrieveBySecret(t *testing.T) {
Identity: clientIdentity,
Secret: testsutil.GenerateUUID(t),
},
Domain: testsutil.GenerateUUID(t),
Metadata: clients.Metadata{},
Status: clients.EnabledStatus,
}
@@ -368,17 +370,27 @@ func TestClientsRetrieveBySecret(t *testing.T) {
cases := []struct {
desc string
secret string
id string
response clients.Client
prefix authn.AuthPrefix
err error
}{
{
desc: "retrieve client by secret successfully",
desc: "retrieve client by secret with no id",
secret: client.Credentials.Secret,
response: clients.Client{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve client by client ID and secret successfully",
secret: client.Credentials.Secret,
id: client.ID,
prefix: authn.BasicAuth,
response: client,
err: nil,
},
{
desc: "retrieve client by invalid secret",
desc: "retrieve client by client ID invalid secret",
secret: "non-existent-secret",
response: clients.Client{},
err: repoerr.ErrNotFound,
@@ -389,10 +401,34 @@ func TestClientsRetrieveBySecret(t *testing.T) {
response: clients.Client{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve client by client ID and secret with an invalid ID type",
secret: client.Credentials.Secret,
id: client.ID,
prefix: authn.DomainAuth,
response: clients.Client{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve client by domain ID and secret successfully",
secret: client.Credentials.Secret,
id: client.Domain,
prefix: authn.DomainAuth,
response: client,
err: nil,
},
{
desc: "retrieve client by domain ID and secret with an invalid ID type",
secret: client.Credentials.Secret,
id: client.Domain,
prefix: authn.BasicAuth,
response: clients.Client{},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
res, err := repo.RetrieveBySecret(context.Background(), tc.secret)
res, err := repo.RetrieveBySecret(context.Background(), tc.secret, tc.id, tc.prefix)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, res, tc.response, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, res))
}
+9 -5
View File
@@ -7,6 +7,7 @@ import (
"context"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
@@ -47,17 +48,20 @@ type service struct {
policy policies.Service
}
func (svc service) Authenticate(ctx context.Context, key string) (string, error) {
id, err := svc.cache.ID(ctx, key)
func (svc service) Authenticate(ctx context.Context, token string) (string, error) {
id, err := svc.cache.ID(ctx, token)
if err == nil {
return id, nil
}
client, err := svc.repo.RetrieveBySecret(ctx, key)
prefix, id, key, err := authn.AuthUnpack(token)
if err != nil && err != authn.ErrNotEncoded {
return "", err
}
client, err := svc.repo.RetrieveBySecret(ctx, key, id, prefix)
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthorization, err)
}
if err := svc.cache.Save(ctx, key, client.ID); err != nil {
if err := svc.cache.Save(ctx, token, client.ID); err != nil {
return "", errors.Wrap(svcerr.ErrAuthorization, err)
}
+4 -3
View File
@@ -11,6 +11,7 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -59,7 +60,7 @@ func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.Cha
func (svc *adapterService) Publish(ctx context.Context, key string, msg *messaging.Message) error {
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{
ClientSecret: key,
Token: authn.AuthPack(authn.DomainAuth, msg.GetDomain(), key),
})
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
@@ -89,7 +90,7 @@ func (svc *adapterService) Publish(ctx context.Context, key string, msg *messagi
func (svc *adapterService) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c Client) error {
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{
ClientSecret: key,
Token: authn.AuthPack(authn.DomainAuth, domainID, key),
})
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
@@ -126,7 +127,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, key, domainID, chanID,
func (svc *adapterService) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) error {
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{
ClientSecret: key,
Token: authn.AuthPack(authn.DomainAuth, domainID, key),
})
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
+1 -1
View File
@@ -260,7 +260,7 @@ func TestPublish(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: tc.key}).Return(tc.authnRes, tc.authnErr)
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.key)}).Return(tc.authnRes, tc.authnErr)
domainsCall := domains.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
DomainId: tc.domainID,
+1 -1
View File
@@ -130,7 +130,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
switch {
case strings.HasPrefix(string(s.Password), "Client"):
secret := strings.TrimPrefix(string(s.Password), apiutil.ClientPrefix)
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: secret})
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
if err != nil {
h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, *topic, err))
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
+1 -1
View File
@@ -341,7 +341,7 @@ func TestPublish(t *testing.T) {
if tc.topic != nil {
internalTopic = strings.TrimPrefix(strings.ReplaceAll(*tc.topic, "/", "."), ".m.")
}
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{ClientSecret: tc.password}).Return(tc.authNRes, tc.authNErr)
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.password)}).Return(tc.authNRes, tc.authNErr)
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
channelsCall := channels.On("Authorize", ctx, mock.Anything).Return(tc.authZRes, tc.authZErr)
repoCall := publisher.On("Publish", ctx, internalTopic, mock.Anything).Return(tc.publishErr)
+1 -2
View File
@@ -37,8 +37,7 @@ service ClientsService {
message AuthnReq {
string client_id = 1;
string client_secret = 2;
string token = 1;
}
message AuthnRes {
+2 -1
View File
@@ -13,6 +13,7 @@ import (
"github.com/absmach/mgate/pkg/session"
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -86,7 +87,7 @@ func (h *handler) AuthConnect(ctx context.Context) error {
pwd := string(s.Password)
res, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: pwd})
res, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.BasicAuth, s.Username, pwd)})
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
+4 -1
View File
@@ -19,6 +19,7 @@ import (
"github.com/absmach/supermq/internal/testsutil"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/mqtt"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -145,11 +146,13 @@ func TestAuthConnect(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.TODO()
password := ""
username := ""
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
password = string(tc.session.Password)
username = tc.session.Username
}
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: password}).Return(tc.authNRes, tc.authNErr)
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.BasicAuth, username, password)}).Return(tc.authNRes, tc.authNErr)
err := handler.AuthConnect(ctx)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
clientsCall.Unset()
+56
View File
@@ -5,6 +5,9 @@ package authn
import (
"context"
"encoding/base64"
"errors"
"strings"
)
type TokenType uint32
@@ -48,3 +51,56 @@ type Session struct {
type Authentication interface {
Authenticate(ctx context.Context, token string) (Session, error)
}
const authSep = ":"
type AuthPrefix int
const (
Unknown AuthPrefix = iota
BasicAuth
DomainAuth
)
var authPrefixStrings = [3]string{
"Unknown",
"Basic",
"Domain",
}
// String returns the string representation (e.g., "Basic") of the AuthPrefix.
func (a AuthPrefix) String() string {
if int(a) < len(authPrefixStrings) {
return authPrefixStrings[a]
}
return "Unknown"
}
// ErrNotEncoded acts similarly to EOF - it does indicate there is no suffix in
// the token, but that does not have to be treated as the error in some cases.
// If token is not base64-encoded, the token is returned as a key alongside with the error.
var ErrNotEncoded = errors.New("token is not encoded with suffix")
func AuthUnpack(token string) (AuthPrefix, string, string, error) {
var auth AuthPrefix
for i, pref := range authPrefixStrings {
if strings.HasPrefix(token, pref) {
token = token[len(pref):]
auth = AuthPrefix(i)
break
}
}
payload, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return Unknown, token, "", err
}
id, key, found := strings.Cut(string(payload), authSep)
if !found {
return auth, id, key, ErrNotEncoded
}
return auth, id, key, nil
}
func AuthPack(prefix AuthPrefix, id, key string) string {
return prefix.String() + base64.StdEncoding.EncodeToString([]byte(id+":"+key))
}
+3 -5
View File
@@ -9,6 +9,7 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -91,13 +92,10 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID,
// authorize checks if the clientKey is authorized to access the channel
// and returns the clientID if it is.
func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, chanID string, msgType connections.ConnType) (string, error) {
authnReq := &grpcClientsV1.AuthnReq{
ClientSecret: clientKey,
}
if strings.HasPrefix(clientKey, "Client") {
authnReq.ClientSecret = extractClientSecret(clientKey)
clientKey = extractClientSecret(clientKey)
}
authnRes, err := svc.clients.Authenticate(ctx, authnReq)
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, domainID, clientKey)})
if err != nil {
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
+3 -2
View File
@@ -15,6 +15,7 @@ import (
chmocks "github.com/absmach/supermq/channels/mocks"
climocks "github.com/absmach/supermq/clients/mocks"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -190,9 +191,9 @@ func TestSubscribe(t *testing.T) {
ClientID: clientID,
Handler: c,
}
authReq := &grpcClientsV1.AuthnReq{ClientSecret: tc.clientKey}
authReq := &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, tc.domainID, tc.clientKey)}
if strings.HasPrefix(tc.clientKey, "Client") {
authReq.ClientSecret = strings.TrimPrefix(tc.clientKey, "Client ")
authReq.Token = authn.AuthPack(authn.DomainAuth, tc.domainID, strings.TrimPrefix(tc.clientKey, "Client "))
}
clientsCall := clients.On("Authenticate", mock.Anything, authReq).Return(tc.authNRes, tc.authNErr)
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
+2 -1
View File
@@ -130,7 +130,8 @@ func TestHandshake(t *testing.T) {
pubsub.On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil)
pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
clients.On("Authenticate", mock.Anything, mock.MatchedBy(func(req *grpcClientsV1.AuthnReq) bool {
return req.ClientSecret == clientKey
_, _, key, _ := smqauthn.AuthUnpack(req.Token)
return key == clientKey
})).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil)
clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: false}, nil)
authn.On("Authenticate", mock.Anything, mock.Anything).Return(smqauthn.Session{}, nil)
+3 -6
View File
@@ -16,6 +16,7 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/authn"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
@@ -194,14 +195,10 @@ func (h *handler) Disconnect(ctx context.Context) error {
}
func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType) (string, string, error) {
authnReq := &grpcClientsV1.AuthnReq{
ClientSecret: token,
}
if strings.HasPrefix(token, "Client") {
authnReq.ClientSecret = extractClientSecret(token)
token = extractClientSecret(token)
}
authnRes, err := h.clients.Authenticate(ctx, authnReq)
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, domainID, token)})
if err != nil {
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
}