mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 06:50:18 +00:00
SMQ-2866 - MQTT auth fails with identical secrets across different domains (#3030)
Continuous Delivery / Build and Push (push) Has been cancelled
Check the consistency of generated files / check-generated-files (push) Has been cancelled
Check License Header / check-license (push) Has been cancelled
Deploy GitHub Pages / swagger-ui (push) Has been cancelled
Continuous Delivery / Build and Push (push) Has been cancelled
Check the consistency of generated files / check-generated-files (push) Has been cancelled
Check License Header / check-license (push) Has been cancelled
Deploy GitHub Pages / swagger-ui (push) Has been cancelled
Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>
This commit is contained in:
@@ -27,8 +27,7 @@ const (
|
||||
|
||||
type AuthnReq struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
ClientId string `protobuf:"bytes,1,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"`
|
||||
ClientSecret string `protobuf:"bytes,2,opt,name=client_secret,json=clientSecret,proto3" json:"client_secret,omitempty"`
|
||||
Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -63,16 +62,9 @@ func (*AuthnReq) Descriptor() ([]byte, []int) {
|
||||
return file_clients_v1_clients_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *AuthnReq) GetClientId() string {
|
||||
func (x *AuthnReq) GetToken() string {
|
||||
if x != nil {
|
||||
return x.ClientId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *AuthnReq) GetClientSecret() string {
|
||||
if x != nil {
|
||||
return x.ClientSecret
|
||||
return x.Token
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -294,10 +286,9 @@ var File_clients_v1_clients_proto protoreflect.FileDescriptor
|
||||
const file_clients_v1_clients_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x18clients/v1/clients.proto\x12\n" +
|
||||
"clients.v1\x1a\x16common/v1/common.proto\"L\n" +
|
||||
"\bAuthnReq\x12\x1b\n" +
|
||||
"\tclient_id\x18\x01 \x01(\tR\bclientId\x12#\n" +
|
||||
"\rclient_secret\x18\x02 \x01(\tR\fclientSecret\"@\n" +
|
||||
"clients.v1\x1a\x16common/v1/common.proto\" \n" +
|
||||
"\bAuthnReq\x12\x14\n" +
|
||||
"\x05token\x18\x01 \x01(\tR\x05token\"@\n" +
|
||||
"\bAuthnRes\x12$\n" +
|
||||
"\rauthenticated\x18\x01 \x01(\bR\rauthenticated\x12\x0e\n" +
|
||||
"\x02id\x18\x02 \x01(\tR\x02id\"<\n" +
|
||||
|
||||
@@ -111,8 +111,7 @@ func (client grpcClient) Authenticate(ctx context.Context, req *grpcClientsV1.Au
|
||||
defer cancel()
|
||||
|
||||
res, err := client.authenticate(ctx, authenticateReq{
|
||||
ClientID: req.GetClientId(),
|
||||
ClientSecret: req.GetClientSecret(),
|
||||
Token: req.GetToken(),
|
||||
})
|
||||
if err != nil {
|
||||
return &grpcClientsV1.AuthnRes{}, decodeError(err)
|
||||
@@ -125,8 +124,7 @@ func (client grpcClient) Authenticate(ctx context.Context, req *grpcClientsV1.Au
|
||||
func encodeAuthenticateRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(authenticateReq)
|
||||
return &grpcClientsV1.AuthnReq{
|
||||
ClientId: req.ClientID,
|
||||
ClientSecret: req.ClientSecret,
|
||||
Token: req.Token,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
func authenticateEndpoint(svc pClients.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request interface{}) (interface{}, error) {
|
||||
req := request.(authenticateReq)
|
||||
id, err := svc.Authenticate(ctx, req.ClientSecret)
|
||||
id, err := svc.Authenticate(ctx, req.Token)
|
||||
if err != nil {
|
||||
return authenticateRes{}, err
|
||||
}
|
||||
|
||||
@@ -97,7 +97,7 @@ func TestAuthenticate(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("Authenticate", mock.Anything, tc.clientSecret).Return(tc.clientID, tc.svcErr)
|
||||
res, err := client.Authenticate(context.Background(), &grpcClientsV1.AuthnReq{ClientSecret: tc.clientSecret})
|
||||
res, err := client.Authenticate(context.Background(), &grpcClientsV1.AuthnReq{Token: tc.clientSecret})
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
assert.Equal(t, tc.resp, res)
|
||||
svcCall.Unset()
|
||||
|
||||
@@ -4,8 +4,7 @@
|
||||
package grpc
|
||||
|
||||
type authenticateReq struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Token string
|
||||
}
|
||||
|
||||
type retrieveEntitiesReq struct {
|
||||
|
||||
@@ -84,8 +84,7 @@ func (s *grpcServer) Authenticate(ctx context.Context, req *grpcClientsV1.AuthnR
|
||||
func decodeAuthorizeRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(*grpcClientsV1.AuthnReq)
|
||||
return authenticateReq{
|
||||
ClientID: req.GetClientId(),
|
||||
ClientSecret: req.GetClientSecret(),
|
||||
Token: req.GetToken(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
+4
-4
@@ -38,8 +38,8 @@ func (tc *clientCache) Save(ctx context.Context, clientKey, clientID string) err
|
||||
if clientKey == "" || clientID == "" {
|
||||
return errors.Wrap(repoerr.ErrCreateEntity, errors.New("client key or client id is empty"))
|
||||
}
|
||||
tkey := fmt.Sprintf("%s:%s", keyPrefix, clientKey)
|
||||
if err := tc.client.Set(ctx, tkey, clientID, tc.keyDuration).Err(); err != nil {
|
||||
ckey := fmt.Sprintf("%s:%s", keyPrefix, clientKey)
|
||||
if err := tc.client.Set(ctx, ckey, clientID, tc.keyDuration).Err(); err != nil {
|
||||
return errors.Wrap(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
@@ -56,8 +56,8 @@ func (tc *clientCache) ID(ctx context.Context, clientKey string) (string, error)
|
||||
return "", repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
tkey := fmt.Sprintf("%s:%s", keyPrefix, clientKey)
|
||||
clientID, err := tc.client.Get(ctx, tkey).Result()
|
||||
ckey := fmt.Sprintf("%s:%s", keyPrefix, clientKey)
|
||||
clientID, err := tc.client.Get(ctx, ckey).Result()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(repoerr.ErrNotFound, err)
|
||||
}
|
||||
+5
-3
@@ -67,8 +67,10 @@ type Repository interface {
|
||||
// operation failure.
|
||||
Save(ctx context.Context, client ...Client) ([]Client, error)
|
||||
|
||||
// RetrieveBySecret retrieves a client based on the secret (key).
|
||||
RetrieveBySecret(ctx context.Context, key string) (Client, error)
|
||||
// RetrieveBySecret retrieves a client based on the secret (key) and domainID.
|
||||
// Domain ID is required because the key is not globally unique,
|
||||
// but unique on the level of Domain.
|
||||
RetrieveBySecret(ctx context.Context, key, id string, prefix authn.AuthPrefix) (Client, error)
|
||||
|
||||
AddConnections(ctx context.Context, conns []Connection) error
|
||||
|
||||
@@ -95,7 +97,7 @@ type Repository interface {
|
||||
roles.Repository
|
||||
}
|
||||
|
||||
// Service specifies an API that must be fullfiled by the domain service
|
||||
// Service specifies an API that must be fulfilled by the domain service
|
||||
// implementation, and all of its decorators (e.g. logging & metrics).
|
||||
type Service interface {
|
||||
// CreateClients creates new client. In case of the failed registration, a
|
||||
|
||||
+25
-12
@@ -11,6 +11,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
@@ -1257,8 +1258,8 @@ func (_c *Repository_RetrieveByIds_Call) RunAndReturn(run func(ctx context.Conte
|
||||
}
|
||||
|
||||
// RetrieveBySecret provides a mock function for the type Repository
|
||||
func (_mock *Repository) RetrieveBySecret(ctx context.Context, key string) (clients.Client, error) {
|
||||
ret := _mock.Called(ctx, key)
|
||||
func (_mock *Repository) RetrieveBySecret(ctx context.Context, key string, id string, prefix authn.AuthPrefix) (clients.Client, error) {
|
||||
ret := _mock.Called(ctx, key, id, prefix)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveBySecret")
|
||||
@@ -1266,16 +1267,16 @@ func (_mock *Repository) RetrieveBySecret(ctx context.Context, key string) (clie
|
||||
|
||||
var r0 clients.Client
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (clients.Client, error)); ok {
|
||||
return returnFunc(ctx, key)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, authn.AuthPrefix) (clients.Client, error)); ok {
|
||||
return returnFunc(ctx, key, id, prefix)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) clients.Client); ok {
|
||||
r0 = returnFunc(ctx, key)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, authn.AuthPrefix) clients.Client); ok {
|
||||
r0 = returnFunc(ctx, key, id, prefix)
|
||||
} else {
|
||||
r0 = ret.Get(0).(clients.Client)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = returnFunc(ctx, key)
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, authn.AuthPrefix) error); ok {
|
||||
r1 = returnFunc(ctx, key, id, prefix)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
@@ -1290,11 +1291,13 @@ type Repository_RetrieveBySecret_Call struct {
|
||||
// RetrieveBySecret is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - key string
|
||||
func (_e *Repository_Expecter) RetrieveBySecret(ctx interface{}, key interface{}) *Repository_RetrieveBySecret_Call {
|
||||
return &Repository_RetrieveBySecret_Call{Call: _e.mock.On("RetrieveBySecret", ctx, key)}
|
||||
// - id string
|
||||
// - prefix authn.AuthPrefix
|
||||
func (_e *Repository_Expecter) RetrieveBySecret(ctx interface{}, key interface{}, id interface{}, prefix interface{}) *Repository_RetrieveBySecret_Call {
|
||||
return &Repository_RetrieveBySecret_Call{Call: _e.mock.On("RetrieveBySecret", ctx, key, id, prefix)}
|
||||
}
|
||||
|
||||
func (_c *Repository_RetrieveBySecret_Call) Run(run func(ctx context.Context, key string)) *Repository_RetrieveBySecret_Call {
|
||||
func (_c *Repository_RetrieveBySecret_Call) Run(run func(ctx context.Context, key string, id string, prefix authn.AuthPrefix)) *Repository_RetrieveBySecret_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
@@ -1304,9 +1307,19 @@ func (_c *Repository_RetrieveBySecret_Call) Run(run func(ctx context.Context, ke
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
var arg3 authn.AuthPrefix
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(authn.AuthPrefix)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
@@ -1317,7 +1330,7 @@ func (_c *Repository_RetrieveBySecret_Call) Return(client clients.Client, err er
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_RetrieveBySecret_Call) RunAndReturn(run func(ctx context.Context, key string) (clients.Client, error)) *Repository_RetrieveBySecret_Call {
|
||||
func (_c *Repository_RetrieveBySecret_Call) RunAndReturn(run func(ctx context.Context, key string, id string, prefix authn.AuthPrefix) (clients.Client, error)) *Repository_RetrieveBySecret_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
@@ -86,13 +87,23 @@ func (repo *clientRepo) Save(ctx context.Context, cls ...clients.Client) ([]clie
|
||||
return reClients, nil
|
||||
}
|
||||
|
||||
func (repo *clientRepo) RetrieveBySecret(ctx context.Context, key string) (clients.Client, error) {
|
||||
func (repo *clientRepo) RetrieveBySecret(ctx context.Context, key, id string, prefix authn.AuthPrefix) (clients.Client, error) {
|
||||
q := fmt.Sprintf(`SELECT id, name, tags, COALESCE(domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, identity, secret, metadata, created_at, updated_at, updated_by, status
|
||||
FROM clients
|
||||
WHERE secret = :secret AND status = %d`, clients.EnabledStatus)
|
||||
switch prefix {
|
||||
case authn.DomainAuth:
|
||||
q += " AND domain_id = :domain_id"
|
||||
case authn.BasicAuth:
|
||||
q += " AND id = :id"
|
||||
default:
|
||||
return clients.Client{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
dbc := DBClient{
|
||||
Secret: key,
|
||||
Domain: id,
|
||||
ID: id,
|
||||
}
|
||||
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, dbc)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/clients/postgres"
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
@@ -358,6 +359,7 @@ func TestClientsRetrieveBySecret(t *testing.T) {
|
||||
Identity: clientIdentity,
|
||||
Secret: testsutil.GenerateUUID(t),
|
||||
},
|
||||
Domain: testsutil.GenerateUUID(t),
|
||||
Metadata: clients.Metadata{},
|
||||
Status: clients.EnabledStatus,
|
||||
}
|
||||
@@ -368,17 +370,27 @@ func TestClientsRetrieveBySecret(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
secret string
|
||||
id string
|
||||
response clients.Client
|
||||
prefix authn.AuthPrefix
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve client by secret successfully",
|
||||
desc: "retrieve client by secret with no id",
|
||||
secret: client.Credentials.Secret,
|
||||
response: clients.Client{},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "retrieve client by client ID and secret successfully",
|
||||
secret: client.Credentials.Secret,
|
||||
id: client.ID,
|
||||
prefix: authn.BasicAuth,
|
||||
response: client,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "retrieve client by invalid secret",
|
||||
desc: "retrieve client by client ID invalid secret",
|
||||
secret: "non-existent-secret",
|
||||
response: clients.Client{},
|
||||
err: repoerr.ErrNotFound,
|
||||
@@ -389,10 +401,34 @@ func TestClientsRetrieveBySecret(t *testing.T) {
|
||||
response: clients.Client{},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "retrieve client by client ID and secret with an invalid ID type",
|
||||
secret: client.Credentials.Secret,
|
||||
id: client.ID,
|
||||
prefix: authn.DomainAuth,
|
||||
response: clients.Client{},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "retrieve client by domain ID and secret successfully",
|
||||
secret: client.Credentials.Secret,
|
||||
id: client.Domain,
|
||||
prefix: authn.DomainAuth,
|
||||
response: client,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "retrieve client by domain ID and secret with an invalid ID type",
|
||||
secret: client.Credentials.Secret,
|
||||
id: client.Domain,
|
||||
prefix: authn.BasicAuth,
|
||||
response: clients.Client{},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
res, err := repo.RetrieveBySecret(context.Background(), tc.secret)
|
||||
res, err := repo.RetrieveBySecret(context.Background(), tc.secret, tc.id, tc.prefix)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, res, tc.response, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, res))
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
@@ -47,17 +48,20 @@ type service struct {
|
||||
policy policies.Service
|
||||
}
|
||||
|
||||
func (svc service) Authenticate(ctx context.Context, key string) (string, error) {
|
||||
id, err := svc.cache.ID(ctx, key)
|
||||
func (svc service) Authenticate(ctx context.Context, token string) (string, error) {
|
||||
id, err := svc.cache.ID(ctx, token)
|
||||
if err == nil {
|
||||
return id, nil
|
||||
}
|
||||
|
||||
client, err := svc.repo.RetrieveBySecret(ctx, key)
|
||||
prefix, id, key, err := authn.AuthUnpack(token)
|
||||
if err != nil && err != authn.ErrNotEncoded {
|
||||
return "", err
|
||||
}
|
||||
client, err := svc.repo.RetrieveBySecret(ctx, key, id, prefix)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
if err := svc.cache.Save(ctx, key, client.ID); err != nil {
|
||||
if err := svc.cache.Save(ctx, token, client.ID); err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
|
||||
|
||||
+4
-3
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -59,7 +60,7 @@ func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.Cha
|
||||
|
||||
func (svc *adapterService) Publish(ctx context.Context, key string, msg *messaging.Message) error {
|
||||
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{
|
||||
ClientSecret: key,
|
||||
Token: authn.AuthPack(authn.DomainAuth, msg.GetDomain(), key),
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
@@ -89,7 +90,7 @@ func (svc *adapterService) Publish(ctx context.Context, key string, msg *messagi
|
||||
|
||||
func (svc *adapterService) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c Client) error {
|
||||
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{
|
||||
ClientSecret: key,
|
||||
Token: authn.AuthPack(authn.DomainAuth, domainID, key),
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
@@ -126,7 +127,7 @@ func (svc *adapterService) Subscribe(ctx context.Context, key, domainID, chanID,
|
||||
|
||||
func (svc *adapterService) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) error {
|
||||
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{
|
||||
ClientSecret: key,
|
||||
Token: authn.AuthPack(authn.DomainAuth, domainID, key),
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
|
||||
@@ -260,7 +260,7 @@ func TestPublish(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: tc.key}).Return(tc.authnRes, tc.authnErr)
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.key)}).Return(tc.authnRes, tc.authnErr)
|
||||
domainsCall := domains.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
DomainId: tc.domainID,
|
||||
|
||||
+1
-1
@@ -130,7 +130,7 @@ func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) e
|
||||
switch {
|
||||
case strings.HasPrefix(string(s.Password), "Client"):
|
||||
secret := strings.TrimPrefix(string(s.Password), apiutil.ClientPrefix)
|
||||
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: secret})
|
||||
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)})
|
||||
if err != nil {
|
||||
h.logger.Warn(fmt.Sprintf(logInfoFailedAuthNClient, secret, *topic, err))
|
||||
return mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication)
|
||||
|
||||
@@ -341,7 +341,7 @@ func TestPublish(t *testing.T) {
|
||||
if tc.topic != nil {
|
||||
internalTopic = strings.TrimPrefix(strings.ReplaceAll(*tc.topic, "/", "."), ".m.")
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{ClientSecret: tc.password}).Return(tc.authNRes, tc.authNErr)
|
||||
clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.password)}).Return(tc.authNRes, tc.authNErr)
|
||||
authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", ctx, mock.Anything).Return(tc.authZRes, tc.authZErr)
|
||||
repoCall := publisher.On("Publish", ctx, internalTopic, mock.Anything).Return(tc.publishErr)
|
||||
|
||||
@@ -37,8 +37,7 @@ service ClientsService {
|
||||
|
||||
|
||||
message AuthnReq {
|
||||
string client_id = 1;
|
||||
string client_secret = 2;
|
||||
string token = 1;
|
||||
}
|
||||
|
||||
message AuthnRes {
|
||||
|
||||
+2
-1
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -86,7 +87,7 @@ func (h *handler) AuthConnect(ctx context.Context) error {
|
||||
|
||||
pwd := string(s.Password)
|
||||
|
||||
res, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ClientSecret: pwd})
|
||||
res, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.BasicAuth, s.Username, pwd)})
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
"github.com/absmach/supermq/mqtt"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -145,11 +146,13 @@ func TestAuthConnect(t *testing.T) {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
password := ""
|
||||
username := ""
|
||||
if tc.session != nil {
|
||||
ctx = session.NewContext(ctx, tc.session)
|
||||
password = string(tc.session.Password)
|
||||
username = tc.session.Username
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ClientSecret: password}).Return(tc.authNRes, tc.authNErr)
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.BasicAuth, username, password)}).Return(tc.authNRes, tc.authNErr)
|
||||
err := handler.AuthConnect(ctx)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
clientsCall.Unset()
|
||||
|
||||
@@ -5,6 +5,9 @@ package authn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TokenType uint32
|
||||
@@ -48,3 +51,56 @@ type Session struct {
|
||||
type Authentication interface {
|
||||
Authenticate(ctx context.Context, token string) (Session, error)
|
||||
}
|
||||
|
||||
const authSep = ":"
|
||||
|
||||
type AuthPrefix int
|
||||
|
||||
const (
|
||||
Unknown AuthPrefix = iota
|
||||
BasicAuth
|
||||
DomainAuth
|
||||
)
|
||||
|
||||
var authPrefixStrings = [3]string{
|
||||
"Unknown",
|
||||
"Basic",
|
||||
"Domain",
|
||||
}
|
||||
|
||||
// String returns the string representation (e.g., "Basic") of the AuthPrefix.
|
||||
func (a AuthPrefix) String() string {
|
||||
if int(a) < len(authPrefixStrings) {
|
||||
return authPrefixStrings[a]
|
||||
}
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
// ErrNotEncoded acts similarly to EOF - it does indicate there is no suffix in
|
||||
// the token, but that does not have to be treated as the error in some cases.
|
||||
// If token is not base64-encoded, the token is returned as a key alongside with the error.
|
||||
var ErrNotEncoded = errors.New("token is not encoded with suffix")
|
||||
|
||||
func AuthUnpack(token string) (AuthPrefix, string, string, error) {
|
||||
var auth AuthPrefix
|
||||
for i, pref := range authPrefixStrings {
|
||||
if strings.HasPrefix(token, pref) {
|
||||
token = token[len(pref):]
|
||||
auth = AuthPrefix(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
payload, err := base64.StdEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
return Unknown, token, "", err
|
||||
}
|
||||
id, key, found := strings.Cut(string(payload), authSep)
|
||||
if !found {
|
||||
return auth, id, key, ErrNotEncoded
|
||||
}
|
||||
return auth, id, key, nil
|
||||
}
|
||||
|
||||
func AuthPack(prefix AuthPrefix, id, key string) string {
|
||||
return prefix.String() + base64.StdEncoding.EncodeToString([]byte(id+":"+key))
|
||||
}
|
||||
|
||||
+3
-5
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -91,13 +92,10 @@ func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID,
|
||||
// authorize checks if the clientKey is authorized to access the channel
|
||||
// and returns the clientID if it is.
|
||||
func (svc *adapterService) authorize(ctx context.Context, clientKey, domainID, chanID string, msgType connections.ConnType) (string, error) {
|
||||
authnReq := &grpcClientsV1.AuthnReq{
|
||||
ClientSecret: clientKey,
|
||||
}
|
||||
if strings.HasPrefix(clientKey, "Client") {
|
||||
authnReq.ClientSecret = extractClientSecret(clientKey)
|
||||
clientKey = extractClientSecret(clientKey)
|
||||
}
|
||||
authnRes, err := svc.clients.Authenticate(ctx, authnReq)
|
||||
authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, domainID, clientKey)})
|
||||
if err != nil {
|
||||
return "", errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
+3
-2
@@ -15,6 +15,7 @@ import (
|
||||
chmocks "github.com/absmach/supermq/channels/mocks"
|
||||
climocks "github.com/absmach/supermq/clients/mocks"
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -190,9 +191,9 @@ func TestSubscribe(t *testing.T) {
|
||||
ClientID: clientID,
|
||||
Handler: c,
|
||||
}
|
||||
authReq := &grpcClientsV1.AuthnReq{ClientSecret: tc.clientKey}
|
||||
authReq := &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, tc.domainID, tc.clientKey)}
|
||||
if strings.HasPrefix(tc.clientKey, "Client") {
|
||||
authReq.ClientSecret = strings.TrimPrefix(tc.clientKey, "Client ")
|
||||
authReq.Token = authn.AuthPack(authn.DomainAuth, tc.domainID, strings.TrimPrefix(tc.clientKey, "Client "))
|
||||
}
|
||||
clientsCall := clients.On("Authenticate", mock.Anything, authReq).Return(tc.authNRes, tc.authNErr)
|
||||
channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{
|
||||
|
||||
@@ -130,7 +130,8 @@ func TestHandshake(t *testing.T) {
|
||||
pubsub.On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
clients.On("Authenticate", mock.Anything, mock.MatchedBy(func(req *grpcClientsV1.AuthnReq) bool {
|
||||
return req.ClientSecret == clientKey
|
||||
_, _, key, _ := smqauthn.AuthUnpack(req.Token)
|
||||
return key == clientKey
|
||||
})).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil)
|
||||
clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: false}, nil)
|
||||
authn.On("Authenticate", mock.Anything, mock.Anything).Return(smqauthn.Session{}, nil)
|
||||
|
||||
+3
-6
@@ -16,6 +16,7 @@ import (
|
||||
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
|
||||
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
@@ -194,14 +195,10 @@ func (h *handler) Disconnect(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (h *handler) authAccess(ctx context.Context, token, domainID, chanID string, msgType connections.ConnType) (string, string, error) {
|
||||
authnReq := &grpcClientsV1.AuthnReq{
|
||||
ClientSecret: token,
|
||||
}
|
||||
if strings.HasPrefix(token, "Client") {
|
||||
authnReq.ClientSecret = extractClientSecret(token)
|
||||
token = extractClientSecret(token)
|
||||
}
|
||||
|
||||
authnRes, err := h.clients.Authenticate(ctx, authnReq)
|
||||
authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.DomainAuth, domainID, token)})
|
||||
if err != nil {
|
||||
return "", "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user