MG-2048 - Implement Personal Access Tokens (PATs) (#2492)

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Steve Munene
2024-12-18 12:33:07 +03:00
committed by GitHub
parent 0d126c3be0
commit e8e17f5530
42 changed files with 5545 additions and 67 deletions
+73 -3
View File
@@ -7,6 +7,7 @@ import (
"context"
"time"
"github.com/absmach/supermq/auth"
grpcapi "github.com/absmach/supermq/auth/api/grpc"
grpcAuthV1 "github.com/absmach/supermq/internal/grpc/auth/v1"
"github.com/go-kit/kit/endpoint"
@@ -17,9 +18,11 @@ import (
const authSvcName = "auth.v1.AuthService"
type authGrpcClient struct {
authenticate endpoint.Endpoint
authorize endpoint.Endpoint
timeout time.Duration
authenticate endpoint.Endpoint
authenticatePAT endpoint.Endpoint
authorize endpoint.Endpoint
authorizePAT endpoint.Endpoint
timeout time.Duration
}
var _ grpcAuthV1.AuthServiceClient = (*authGrpcClient)(nil)
@@ -35,6 +38,14 @@ func NewAuthClient(conn *grpc.ClientConn, timeout time.Duration) grpcAuthV1.Auth
decodeIdentifyResponse,
grpcAuthV1.AuthNRes{},
).Endpoint(),
authenticatePAT: kitgrpc.NewClient(
conn,
authSvcName,
"AuthenticatePAT",
encodeIdentifyRequest,
decodeIdentifyPATResponse,
grpcAuthV1.AuthNRes{},
).Endpoint(),
authorize: kitgrpc.NewClient(
conn,
authSvcName,
@@ -43,6 +54,14 @@ func NewAuthClient(conn *grpc.ClientConn, timeout time.Duration) grpcAuthV1.Auth
decodeAuthorizeResponse,
grpcAuthV1.AuthZRes{},
).Endpoint(),
authorizePAT: kitgrpc.NewClient(
conn,
authSvcName,
"AuthorizePAT",
encodeAuthorizePATRequest,
decodeAuthorizeResponse,
grpcAuthV1.AuthZRes{},
).Endpoint(),
timeout: timeout,
}
}
@@ -69,6 +88,23 @@ func decodeIdentifyResponse(_ context.Context, grpcRes interface{}) (interface{}
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), domainID: res.GetDomainId()}, nil
}
func (client authGrpcClient) AuthenticatePAT(ctx context.Context, token *grpcAuthV1.AuthNReq, _ ...grpc.CallOption) (*grpcAuthV1.AuthNRes, error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
res, err := client.authenticatePAT(ctx, authenticateReq{token: token.GetToken()})
if err != nil {
return &grpcAuthV1.AuthNRes{}, grpcapi.DecodeError(err)
}
ir := res.(authenticateRes)
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID}, nil
}
func decodeIdentifyPATResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
res := grpcRes.(*grpcAuthV1.AuthNRes)
return authenticateRes{id: res.GetId(), userID: res.GetUserId()}, nil
}
func (client authGrpcClient) Authorize(ctx context.Context, req *grpcAuthV1.AuthZReq, _ ...grpc.CallOption) (r *grpcAuthV1.AuthZRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
@@ -109,3 +145,37 @@ func encodeAuthorizeRequest(_ context.Context, grpcReq interface{}) (interface{}
Object: req.Object,
}, nil
}
func (client authGrpcClient) AuthorizePAT(ctx context.Context, req *grpcAuthV1.AuthZPatReq, _ ...grpc.CallOption) (r *grpcAuthV1.AuthZRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
res, err := client.authorizePAT(ctx, authPATReq{
userID: req.GetUserId(),
patID: req.GetPatId(),
platformEntityType: auth.PlatformEntityType(req.GetPlatformEntityType()),
optionalDomainID: req.GetOptionalDomainId(),
optionalDomainEntityType: auth.DomainEntityType(req.GetOptionalDomainEntityType()),
operation: auth.OperationType(req.GetOperation()),
entityIDs: req.GetEntityIds(),
})
if err != nil {
return &grpcAuthV1.AuthZRes{}, grpcapi.DecodeError(err)
}
ar := res.(authorizeRes)
return &grpcAuthV1.AuthZRes{Authorized: ar.authorized, Id: ar.id}, nil
}
func encodeAuthorizePATRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(authPATReq)
return &grpcAuthV1.AuthZPatReq{
UserId: req.userID,
PatId: req.patID,
PlatformEntityType: uint32(req.platformEntityType),
OptionalDomainId: req.optionalDomainID,
OptionalDomainEntityType: uint32(req.optionalDomainEntityType),
Operation: uint32(req.operation),
EntityIds: req.entityIDs,
}, nil
}
+31
View File
@@ -27,6 +27,22 @@ func authenticateEndpoint(svc auth.Service) endpoint.Endpoint {
}
}
func authenticatePATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authenticateReq)
if err := req.validate(); err != nil {
return authenticateRes{}, err
}
pat, err := svc.IdentifyPAT(ctx, req.token)
if err != nil {
return authenticateRes{}, err
}
return authenticateRes{id: pat.ID, userID: pat.User}, nil
}
}
func authorizeEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authReq)
@@ -50,3 +66,18 @@ func authorizeEndpoint(svc auth.Service) endpoint.Endpoint {
return authorizeRes{authorized: true}, nil
}
}
func authorizePATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authPATReq)
if err := req.validate(); err != nil {
return authorizeRes{}, err
}
err := svc.AuthorizePAT(ctx, req.userID, req.patID, req.platformEntityType, req.optionalDomainID, req.optionalDomainEntityType, req.operation, req.entityIDs...)
if err != nil {
return authorizeRes{authorized: false}, err
}
return authorizeRes{authorized: true}, nil
}
}
+164 -16
View File
@@ -41,12 +41,15 @@ const (
invalidDuration = 7 * 24 * time.Hour
validToken = "valid"
inValidToken = "invalid"
validPATToken = "valid"
inValidPATToken = "invalid"
validPolicy = "valid"
)
var (
domainID = testsutil.GenerateUUID(&testing.T{})
authAddr = fmt.Sprintf("localhost:%d", port)
clientID = testsutil.GenerateUUID(&testing.T{})
)
func startGRPCServer(svc auth.Service, port int) *grpc.Server {
@@ -63,8 +66,8 @@ func startGRPCServer(svc auth.Service, port int) *grpc.Server {
func TestIdentify(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
defer conn.Close()
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()
grpcClient := grpcapi.NewAuthClient(conn, time.Second)
cases := []struct {
@@ -96,20 +99,23 @@ func TestIdentify(t *testing.T) {
}
for _, tc := range cases {
svcCall := svc.On("Identify", mock.Anything, mock.Anything, mock.Anything).Return(auth.Key{Subject: id, User: email, Domain: domainID}, tc.svcErr)
idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
if idt != nil {
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Identify", mock.Anything, mock.Anything).Return(auth.Key{Subject: id, User: email, Domain: domainID}, tc.svcErr)
idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
if idt != nil {
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
func TestAuthorize(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
defer conn.Close()
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()
grpcClient := grpcapi.NewAuthClient(conn, time.Second)
cases := []struct {
@@ -219,12 +225,154 @@ func TestAuthorize(t *testing.T) {
},
}
for _, tc := range cases {
svccall := svc.On("Authorize", mock.Anything, mock.Anything).Return(tc.err)
ar, err := grpcClient.Authorize(context.Background(), tc.authRequest)
if ar != nil {
assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svccall.Unset()
t.Run(tc.desc, func(t *testing.T) {
svccall := svc.On("Authorize", mock.Anything, mock.Anything).Return(tc.err)
ar, err := grpcClient.Authorize(context.Background(), tc.authRequest)
if ar != nil {
assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svccall.Unset()
})
}
}
func TestIdentifyPAT(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()
grpcClient := grpcapi.NewAuthClient(conn, time.Second)
cases := []struct {
desc string
token string
idt *grpcAuthV1.AuthNRes
svcErr error
err error
}{
{
desc: "authenticate user with valid user token",
token: validToken,
idt: &grpcAuthV1.AuthNRes{Id: id, UserId: clientID},
err: nil,
},
{
desc: "authenticate user with invalid user token",
token: "invalid",
idt: &grpcAuthV1.AuthNRes{},
svcErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "authenticate user with empty token",
token: "",
idt: &grpcAuthV1.AuthNRes{},
err: apiutil.ErrBearerToken,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("IdentifyPAT", mock.Anything, tc.token).Return(auth.PAT{ID: id, User: clientID, IssuedAt: time.Now()}, tc.svcErr)
idt, err := grpcClient.AuthenticatePAT(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
if idt != nil {
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
func TestAuthorizePAT(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()
grpcClient := grpcapi.NewAuthClient(conn, time.Second)
cases := []struct {
desc string
token string
authRequest *grpcAuthV1.AuthZPatReq
authResponse *grpcAuthV1.AuthZRes
err error
}{
{
desc: "authorize user with authorized token",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
UserId: id,
PatId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
err: nil,
},
{
desc: "authorize user with unauthorized token",
token: inValidPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
UserId: id,
PatId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: svcerr.ErrAuthorization,
},
{
desc: "authorize user with missing user id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
PatId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingUserID,
},
{
desc: "authorize user with missing pat id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
UserId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPATID,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svccall := svc.On("AuthorizePAT",
mock.Anything,
tc.authRequest.UserId,
tc.authRequest.PatId,
mock.Anything,
tc.authRequest.OptionalDomainId,
mock.Anything,
mock.Anything,
mock.Anything).Return(tc.err)
ar, err := grpcClient.AuthorizePAT(context.Background(), tc.authRequest)
if ar != nil {
assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svccall.Unset()
})
}
}
+21
View File
@@ -4,6 +4,7 @@
package auth
import (
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/apiutil"
)
@@ -49,3 +50,23 @@ func (req authReq) validate() error {
return nil
}
type authPATReq struct {
userID string
patID string
platformEntityType auth.PlatformEntityType
optionalDomainID string
optionalDomainEntityType auth.DomainEntityType
operation auth.OperationType
entityIDs []string
}
func (req authPATReq) validate() error {
if req.userID == "" {
return apiutil.ErrMissingUserID
}
if req.patID == "" {
return apiutil.ErrMissingPATID
}
return nil
}
+50 -2
View File
@@ -16,8 +16,10 @@ var _ grpcAuthV1.AuthServiceServer = (*authGrpcServer)(nil)
type authGrpcServer struct {
grpcAuthV1.UnimplementedAuthServiceServer
authorize kitgrpc.Handler
authenticate kitgrpc.Handler
authorize kitgrpc.Handler
authenticate kitgrpc.Handler
authenticatePAT kitgrpc.Handler
authorizePAT kitgrpc.Handler
}
// NewAuthServer returns new AuthnServiceServer instance.
@@ -34,6 +36,18 @@ func NewAuthServer(svc auth.Service) grpcAuthV1.AuthServiceServer {
decodeAuthenticateRequest,
encodeAuthenticateResponse,
),
authenticatePAT: kitgrpc.NewServer(
(authenticatePATEndpoint(svc)),
decodeAuthenticateRequest,
encodeAuthenticatePATResponse,
),
authorizePAT: kitgrpc.NewServer(
(authorizePATEndpoint(svc)),
decodeAuthorizePATRequest,
encodeAuthorizeResponse,
),
}
}
@@ -45,6 +59,14 @@ func (s *authGrpcServer) Authenticate(ctx context.Context, req *grpcAuthV1.AuthN
return res.(*grpcAuthV1.AuthNRes), nil
}
func (s *authGrpcServer) AuthenticatePAT(ctx context.Context, req *grpcAuthV1.AuthNReq) (*grpcAuthV1.AuthNRes, error) {
_, res, err := s.authenticatePAT.ServeGRPC(ctx, req)
if err != nil {
return nil, grpcapi.EncodeError(err)
}
return res.(*grpcAuthV1.AuthNRes), nil
}
func (s *authGrpcServer) Authorize(ctx context.Context, req *grpcAuthV1.AuthZReq) (*grpcAuthV1.AuthZRes, error) {
_, res, err := s.authorize.ServeGRPC(ctx, req)
if err != nil {
@@ -63,6 +85,11 @@ func encodeAuthenticateResponse(_ context.Context, grpcRes interface{}) (interfa
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, DomainId: res.domainID}, nil
}
func encodeAuthenticatePATResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
res := grpcRes.(authenticateRes)
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID}, nil
}
func decodeAuthorizeRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*grpcAuthV1.AuthZReq)
return authReq{
@@ -81,3 +108,24 @@ func encodeAuthorizeResponse(_ context.Context, grpcRes interface{}) (interface{
res := grpcRes.(authorizeRes)
return &grpcAuthV1.AuthZRes{Authorized: res.authorized, Id: res.id}, nil
}
func decodeAuthorizePATRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*grpcAuthV1.AuthZPatReq)
return authPATReq{
userID: req.GetUserId(),
patID: req.GetPatId(),
platformEntityType: auth.PlatformEntityType(req.GetPlatformEntityType()),
optionalDomainID: req.GetOptionalDomainId(),
optionalDomainEntityType: auth.DomainEntityType(req.GetOptionalDomainEntityType()),
operation: auth.OperationType(req.GetOperation()),
entityIDs: req.GetEntityIds(),
}, nil
}
func (s *authGrpcServer) AuthorizePAT(ctx context.Context, req *grpcAuthV1.AuthZPatReq) (*grpcAuthV1.AuthZRes, error) {
_, res, err := s.authorizePAT.ServeGRPC(ctx, req)
if err != nil {
return nil, grpcapi.EncodeError(err)
}
return res.(*grpcAuthV1.AuthZRes), nil
}
+3 -1
View File
@@ -69,12 +69,14 @@ func (tr testRequest) make() (*http.Response, error) {
func newService() (auth.Service, *mocks.KeyRepository) {
krepo := new(mocks.KeyRepository)
pRepo := new(mocks.PATSRepository)
hash := new(mocks.Hasher)
idProvider := uuid.NewMock()
pService := new(policymocks.Service)
pEvaluator := new(policymocks.Evaluator)
t := jwt.New([]byte(secret))
return auth.New(krepo, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo
return auth.New(krepo, pRepo, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo
}
func newServer(svc auth.Service) *httptest.Server {
+187
View File
@@ -0,0 +1,187 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package pats
import (
"context"
"github.com/absmach/supermq/auth"
"github.com/go-kit/kit/endpoint"
)
func createPATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(createPatReq)
if err := req.validate(); err != nil {
return nil, err
}
pat, err := svc.CreatePAT(ctx, req.token, req.Name, req.Description, req.Duration, req.Scope)
if err != nil {
return nil, err
}
return createPatRes{pat}, nil
}
}
func retrievePATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(retrievePatReq)
if err := req.validate(); err != nil {
return nil, err
}
pat, err := svc.RetrievePAT(ctx, req.token, req.id)
if err != nil {
return nil, err
}
return retrievePatRes{pat}, nil
}
}
func updatePATNameEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(updatePatNameReq)
if err := req.validate(); err != nil {
return nil, err
}
pat, err := svc.UpdatePATName(ctx, req.token, req.id, req.Name)
if err != nil {
return nil, err
}
return updatePatNameRes{pat}, nil
}
}
func updatePATDescriptionEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(updatePatDescriptionReq)
if err := req.validate(); err != nil {
return nil, err
}
pat, err := svc.UpdatePATDescription(ctx, req.token, req.id, req.Description)
if err != nil {
return nil, err
}
return updatePatDescriptionRes{pat}, nil
}
}
func listPATSEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(listPatsReq)
if err := req.validate(); err != nil {
return nil, err
}
pm := auth.PATSPageMeta{
Limit: req.limit,
Offset: req.offset,
}
patsPage, err := svc.ListPATS(ctx, req.token, pm)
if err != nil {
return nil, err
}
return listPatsRes{patsPage}, nil
}
}
func deletePATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(deletePatReq)
if err := req.validate(); err != nil {
return nil, err
}
if err := svc.DeletePAT(ctx, req.token, req.id); err != nil {
return nil, err
}
return deletePatRes{}, nil
}
}
func resetPATSecretEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(resetPatSecretReq)
if err := req.validate(); err != nil {
return nil, err
}
pat, err := svc.ResetPATSecret(ctx, req.token, req.id, req.Duration)
if err != nil {
return nil, err
}
return resetPatSecretRes{pat}, nil
}
}
func revokePATSecretEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(revokePatSecretReq)
if err := req.validate(); err != nil {
return nil, err
}
if err := svc.RevokePATSecret(ctx, req.token, req.id); err != nil {
return nil, err
}
return revokePatSecretRes{}, nil
}
}
func addPATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(addPatScopeEntryReq)
if err := req.validate(); err != nil {
return nil, err
}
scope, err := svc.AddPATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...)
if err != nil {
return nil, err
}
return addPatScopeEntryRes{scope}, nil
}
}
func removePATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(removePatScopeEntryReq)
if err := req.validate(); err != nil {
return nil, err
}
scope, err := svc.RemovePATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...)
if err != nil {
return nil, err
}
return removePatScopeEntryRes{scope}, nil
}
}
func clearPATAllScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(clearAllScopeEntryReq)
if err := req.validate(); err != nil {
return nil, err
}
if err := svc.ClearPATAllScopeEntry(ctx, req.token, req.id); err != nil {
return nil, err
}
return clearAllScopeEntryRes{}, nil
}
}
+303
View File
@@ -0,0 +1,303 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package pats
import (
"encoding/json"
"strings"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/apiutil"
)
type createPatReq struct {
token string
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
Scope auth.Scope `json:"scope,omitempty"`
}
func (cpr *createPatReq) UnmarshalJSON(data []byte) error {
var temp struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Duration string `json:"duration,omitempty"`
Scope auth.Scope `json:"scope,omitempty"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
duration, err := time.ParseDuration(temp.Duration)
if err != nil {
return err
}
cpr.Name = temp.Name
cpr.Description = temp.Description
cpr.Duration = duration
cpr.Scope = temp.Scope
return nil
}
func (req createPatReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if strings.TrimSpace(req.Name) == "" {
return apiutil.ErrMissingName
}
return nil
}
type retrievePatReq struct {
token string
id string
}
func (req retrievePatReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type updatePatNameReq struct {
token string
id string
Name string `json:"name,omitempty"`
}
func (req updatePatNameReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
if strings.TrimSpace(req.Name) == "" {
return apiutil.ErrMissingName
}
return nil
}
type updatePatDescriptionReq struct {
token string
id string
Description string `json:"description,omitempty"`
}
func (req updatePatDescriptionReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
if strings.TrimSpace(req.Description) == "" {
return apiutil.ErrMissingDescription
}
return nil
}
type listPatsReq struct {
token string
offset uint64
limit uint64
}
func (req listPatsReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
return nil
}
type deletePatReq struct {
token string
id string
}
func (req deletePatReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type resetPatSecretReq struct {
token string
id string
Duration time.Duration `json:"duration,omitempty"`
}
func (rspr *resetPatSecretReq) UnmarshalJSON(data []byte) error {
var temp struct {
Duration string `json:"duration,omitempty"`
}
err := json.Unmarshal(data, &temp)
if err != nil {
return err
}
rspr.Duration, err = time.ParseDuration(temp.Duration)
if err != nil {
return err
}
return nil
}
func (req resetPatSecretReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type revokePatSecretReq struct {
token string
id string
}
func (req revokePatSecretReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type addPatScopeEntryReq struct {
token string
id string
PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"`
OptionalDomainID string `json:"optional_domain_id,omitempty"`
OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"`
Operation auth.OperationType `json:"operation,omitempty"`
EntityIDs []string `json:"entity_ids,omitempty"`
}
func (apser *addPatScopeEntryReq) UnmarshalJSON(data []byte) error {
var temp struct {
PlatformEntityType string `json:"platform_entity_type,omitempty"`
OptionalDomainID string `json:"optional_domain_id,omitempty"`
OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"`
Operation string `json:"operation,omitempty"`
EntityIDs []string `json:"entity_ids,omitempty"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType)
if err != nil {
return err
}
odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType)
if err != nil {
return err
}
op, err := auth.ParseOperationType(temp.Operation)
if err != nil {
return err
}
apser.PlatformEntityType = pet
apser.OptionalDomainID = temp.OptionalDomainID
apser.OptionalDomainEntityType = odt
apser.Operation = op
apser.EntityIDs = temp.EntityIDs
return nil
}
func (req addPatScopeEntryReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type removePatScopeEntryReq struct {
token string
id string
PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"`
OptionalDomainID string `json:"optional_domain_id,omitempty"`
OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"`
Operation auth.OperationType `json:"operation,omitempty"`
EntityIDs []string `json:"entity_ids,omitempty"`
}
func (rpser *removePatScopeEntryReq) UnmarshalJSON(data []byte) error {
var temp struct {
PlatformEntityType string `json:"platform_entity_type,omitempty"`
OptionalDomainID string `json:"optional_domain_id,omitempty"`
OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"`
Operation string `json:"operation,omitempty"`
EntityIDs []string `json:"entity_ids,omitempty"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType)
if err != nil {
return err
}
odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType)
if err != nil {
return err
}
op, err := auth.ParseOperationType(temp.Operation)
if err != nil {
return err
}
rpser.PlatformEntityType = pet
rpser.OptionalDomainID = temp.OptionalDomainID
rpser.OptionalDomainEntityType = odt
rpser.Operation = op
rpser.EntityIDs = temp.EntityIDs
return nil
}
func (req removePatScopeEntryReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type clearAllScopeEntryReq struct {
token string
id string
}
func (req clearAllScopeEntryReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
+194
View File
@@ -0,0 +1,194 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package pats
import (
"net/http"
"github.com/absmach/supermq"
"github.com/absmach/supermq/auth"
)
var (
_ supermq.Response = (*createPatRes)(nil)
_ supermq.Response = (*retrievePatRes)(nil)
_ supermq.Response = (*updatePatNameRes)(nil)
_ supermq.Response = (*updatePatDescriptionRes)(nil)
_ supermq.Response = (*deletePatRes)(nil)
_ supermq.Response = (*resetPatSecretRes)(nil)
_ supermq.Response = (*revokePatSecretRes)(nil)
_ supermq.Response = (*addPatScopeEntryRes)(nil)
_ supermq.Response = (*removePatScopeEntryRes)(nil)
_ supermq.Response = (*clearAllScopeEntryRes)(nil)
)
type createPatRes struct {
auth.PAT
}
func (res createPatRes) Code() int {
return http.StatusCreated
}
func (res createPatRes) Headers() map[string]string {
return map[string]string{}
}
func (res createPatRes) Empty() bool {
return false
}
type retrievePatRes struct {
auth.PAT
}
func (res retrievePatRes) Code() int {
return http.StatusOK
}
func (res retrievePatRes) Headers() map[string]string {
return map[string]string{}
}
func (res retrievePatRes) Empty() bool {
return false
}
type updatePatNameRes struct {
auth.PAT
}
func (res updatePatNameRes) Code() int {
return http.StatusAccepted
}
func (res updatePatNameRes) Headers() map[string]string {
return map[string]string{}
}
func (res updatePatNameRes) Empty() bool {
return false
}
type updatePatDescriptionRes struct {
auth.PAT
}
func (res updatePatDescriptionRes) Code() int {
return http.StatusAccepted
}
func (res updatePatDescriptionRes) Headers() map[string]string {
return map[string]string{}
}
func (res updatePatDescriptionRes) Empty() bool {
return false
}
type listPatsRes struct {
auth.PATSPage
}
func (res listPatsRes) Code() int {
return http.StatusOK
}
func (res listPatsRes) Headers() map[string]string {
return map[string]string{}
}
func (res listPatsRes) Empty() bool {
return false
}
type deletePatRes struct{}
func (res deletePatRes) Code() int {
return http.StatusNoContent
}
func (res deletePatRes) Headers() map[string]string {
return map[string]string{}
}
func (res deletePatRes) Empty() bool {
return true
}
type resetPatSecretRes struct {
auth.PAT
}
func (res resetPatSecretRes) Code() int {
return http.StatusOK
}
func (res resetPatSecretRes) Headers() map[string]string {
return map[string]string{}
}
func (res resetPatSecretRes) Empty() bool {
return false
}
type revokePatSecretRes struct{}
func (res revokePatSecretRes) Code() int {
return http.StatusNoContent
}
func (res revokePatSecretRes) Headers() map[string]string {
return map[string]string{}
}
func (res revokePatSecretRes) Empty() bool {
return true
}
type addPatScopeEntryRes struct {
auth.Scope
}
func (res addPatScopeEntryRes) Code() int {
return http.StatusOK
}
func (res addPatScopeEntryRes) Headers() map[string]string {
return map[string]string{}
}
func (res addPatScopeEntryRes) Empty() bool {
return false
}
type removePatScopeEntryRes struct {
auth.Scope
}
func (res removePatScopeEntryRes) Code() int {
return http.StatusOK
}
func (res removePatScopeEntryRes) Headers() map[string]string {
return map[string]string{}
}
func (res removePatScopeEntryRes) Empty() bool {
return false
}
type clearAllScopeEntryRes struct{}
func (res clearAllScopeEntryRes) Code() int {
return http.StatusOK
}
func (res clearAllScopeEntryRes) Headers() map[string]string {
return map[string]string{}
}
func (res clearAllScopeEntryRes) Empty() bool {
return true
}
+300
View File
@@ -0,0 +1,300 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package pats
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"strings"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/internal/api"
"github.com/absmach/supermq/pkg/apiutil"
"github.com/absmach/supermq/pkg/errors"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
)
const (
contentType = "application/json"
defInterval = "30d"
patPrefix = "pat_"
)
// MakeHandler returns a HTTP handler for API endpoints.
func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
mux.Route("/pats", func(r chi.Router) {
r.Post("/", kithttp.NewServer(
createPATEndpoint(svc),
decodeCreatePATRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Get("/", kithttp.NewServer(
listPATSEndpoint(svc),
decodeListPATSRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Route("/{id}", func(r chi.Router) {
r.Get("/", kithttp.NewServer(
retrievePATEndpoint(svc),
decodeRetrievePATRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Patch("/name", kithttp.NewServer(
updatePATNameEndpoint(svc),
decodeUpdatePATNameRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Patch("/description", kithttp.NewServer(
updatePATDescriptionEndpoint(svc),
decodeUpdatePATDescriptionRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Delete("/", kithttp.NewServer(
deletePATEndpoint(svc),
decodeDeletePATRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Route("/secret", func(r chi.Router) {
r.Patch("/reset", kithttp.NewServer(
resetPATSecretEndpoint(svc),
decodeResetPATSecretRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Patch("/revoke", kithttp.NewServer(
revokePATSecretEndpoint(svc),
decodeRevokePATSecretRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
})
r.Route("/scope", func(r chi.Router) {
r.Patch("/add", kithttp.NewServer(
addPATScopeEntryEndpoint(svc),
decodeAddPATScopeEntryRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Patch("/remove", kithttp.NewServer(
removePATScopeEntryEndpoint(svc),
decodeRemovePATScopeEntryRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Delete("/", kithttp.NewServer(
clearPATAllScopeEntryEndpoint(svc),
decodeClearPATAllScopeEntryRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
})
})
})
return mux
}
func decodeCreatePATRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := createPatReq{token: token}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
}
return req, nil
}
func decodeRetrievePATRequest(_ context.Context, r *http.Request) (interface{}, error) {
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := retrievePatReq{
token: token,
id: chi.URLParam(r, "id"),
}
return req, nil
}
func decodeUpdatePATNameRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := updatePatNameReq{
token: token,
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
}
return req, nil
}
func decodeUpdatePATDescriptionRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := updatePatDescriptionReq{
token: token,
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
}
return req, nil
}
func decodeListPATSRequest(_ context.Context, r *http.Request) (interface{}, error) {
l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := listPatsReq{
token: token,
limit: l,
offset: o,
}
return req, nil
}
func decodeDeletePATRequest(_ context.Context, r *http.Request) (interface{}, error) {
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
return deletePatReq{
token: token,
id: chi.URLParam(r, "id"),
}, nil
}
func decodeResetPATSecretRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := resetPatSecretReq{
token: token,
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
}
return req, nil
}
func decodeRevokePATSecretRequest(_ context.Context, r *http.Request) (interface{}, error) {
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
return revokePatSecretReq{
token: token,
id: chi.URLParam(r, "id"),
}, nil
}
func decodeAddPATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := addPatScopeEntryReq{
token: token,
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
}
return req, nil
}
func decodeRemovePATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := removePatScopeEntryReq{
token: token,
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
}
return req, nil
}
func decodeClearPATAllScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
return clearAllScopeEntryReq{
token: token,
id: chi.URLParam(r, "id"),
}, nil
}
+2
View File
@@ -9,6 +9,7 @@ import (
"github.com/absmach/supermq"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/auth/api/http/keys"
"github.com/absmach/supermq/auth/api/http/pats"
"github.com/go-chi/chi/v5"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
@@ -18,6 +19,7 @@ func MakeHandler(svc auth.Service, logger *slog.Logger, instanceID string) http.
mux := chi.NewRouter()
mux = keys.MakeHandler(svc, mux, logger)
mux = pats.MakeHandler(svc, mux, logger)
mux.Get("/health", supermq.Health("auth", instanceID))
mux.Handle("/metrics", promhttp.Handler())
+250
View File
@@ -124,3 +124,253 @@ func (lm *loggingMiddleware) Authorize(ctx context.Context, pr policies.Policy)
}(time.Now())
return lm.svc.Authorize(ctx, pr)
}
func (lm *loggingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (pa auth.PAT, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("name", name),
slog.String("description", description),
slog.String("pat_duration", duration.String()),
slog.String("scope", scope.String()),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Create PAT failed", args...)
return
}
lm.logger.Info("Create PAT completed successfully", args...)
}(time.Now())
return lm.svc.CreatePAT(ctx, token, name, description, duration, scope)
}
func (lm *loggingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (pa auth.PAT, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
slog.String("name", name),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Update PAT name failed", args...)
return
}
lm.logger.Info("Update PAT name completed successfully", args...)
}(time.Now())
return lm.svc.UpdatePATName(ctx, token, patID, name)
}
func (lm *loggingMiddleware) UpdatePATDescription(ctx context.Context, token, patID, description string) (pa auth.PAT, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
slog.String("description", description),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Update PAT description failed", args...)
return
}
lm.logger.Info("Update PAT description completed successfully", args...)
}(time.Now())
return lm.svc.UpdatePATDescription(ctx, token, patID, description)
}
func (lm *loggingMiddleware) RetrievePAT(ctx context.Context, token, patID string) (pa auth.PAT, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Retrieve PAT failed", args...)
return
}
lm.logger.Info("Retrieve PAT completed successfully", args...)
}(time.Now())
return lm.svc.RetrievePAT(ctx, token, patID)
}
func (lm *loggingMiddleware) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (pp auth.PATSPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.Uint64("limit", pm.Limit),
slog.Uint64("offset", pm.Offset),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("List PATS failed", args...)
return
}
lm.logger.Info("List PATS completed successfully", args...)
}(time.Now())
return lm.svc.ListPATS(ctx, token, pm)
}
func (lm *loggingMiddleware) DeletePAT(ctx context.Context, token, patID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Delete PAT failed", args...)
return
}
lm.logger.Info("Delete PAT completed successfully", args...)
}(time.Now())
return lm.svc.DeletePAT(ctx, token, patID)
}
func (lm *loggingMiddleware) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (pa auth.PAT, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
slog.String("pat_duration", duration.String()),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Reset PAT secret failed", args...)
return
}
lm.logger.Info("Reset PAT secret completed successfully", args...)
}(time.Now())
return lm.svc.ResetPATSecret(ctx, token, patID, duration)
}
func (lm *loggingMiddleware) RevokePATSecret(ctx context.Context, token, patID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Revoke PAT secret failed", args...)
return
}
lm.logger.Info("Revoke PAT secret completed successfully", args...)
}(time.Now())
return lm.svc.RevokePATSecret(ctx, token, patID)
}
func (lm *loggingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
slog.String("platform_entity_type", platformEntityType.String()),
slog.String("optional_domain_id", optionalDomainID),
slog.String("optional_domain_entity_type", optionalDomainEntityType.String()),
slog.String("operation", operation.String()),
slog.Any("entities", entityIDs),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Add entry to PAT scope failed", args...)
return
}
lm.logger.Info("Add entry to PAT scope completed successfully", args...)
}(time.Now())
return lm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (lm *loggingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
slog.String("platform_entity_type", platformEntityType.String()),
slog.String("optional_domain_id", optionalDomainID),
slog.String("optional_domain_entity_type", optionalDomainEntityType.String()),
slog.String("operation", operation.String()),
slog.Any("entities", entityIDs),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Remove entry from PAT scope failed", args...)
return
}
lm.logger.Info("Remove entry from PAT scope completed successfully", args...)
}(time.Now())
return lm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (lm *loggingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Clear all entry from PAT scope failed", args...)
return
}
lm.logger.Info("Clear all entry from PAT scope completed successfully", args...)
}(time.Now())
return lm.svc.ClearPATAllScopeEntry(ctx, token, patID)
}
func (lm *loggingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (pa auth.PAT, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Identify PAT failed", args...)
return
}
lm.logger.Info("Identify PAT completed successfully", args...)
}(time.Now())
return lm.svc.IdentifyPAT(ctx, paToken)
}
func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("platform_entity_type", platformEntityType.String()),
slog.String("optional_domain_id", optionalDomainID),
slog.String("optional_domain_entity_type", optionalDomainEntityType.String()),
slog.String("operation", operation.String()),
slog.Any("entities", entityIDs),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Authorize PAT failed complete successfully", args...)
return
}
lm.logger.Info("Authorize PAT completed successfully", args...)
}(time.Now())
return lm.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (lm *loggingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", userID),
slog.String("pat_id", patID),
slog.String("platform_entity_type", platformEntityType.String()),
slog.String("optional_domain_id", optionalDomainID),
slog.String("optional_domain_entity_type", optionalDomainEntityType.String()),
slog.String("operation", operation.String()),
slog.Any("entities", entityIDs),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Check PAT failed complete successfully", args...)
return
}
lm.logger.Info("Check PAT completed successfully", args...)
}(time.Now())
return lm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
+112
View File
@@ -74,3 +74,115 @@ func (ms *metricsMiddleware) Authorize(ctx context.Context, pr policies.Policy)
}(time.Now())
return ms.svc.Authorize(ctx, pr)
}
func (ms *metricsMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) {
defer func(begin time.Time) {
ms.counter.With("method", "create_pat").Add(1)
ms.latency.With("method", "create_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.CreatePAT(ctx, token, name, description, duration, scope)
}
func (ms *metricsMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) {
defer func(begin time.Time) {
ms.counter.With("method", "update_pat_name").Add(1)
ms.latency.With("method", "update_pat_name").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.UpdatePATName(ctx, token, patID, name)
}
func (ms *metricsMiddleware) UpdatePATDescription(ctx context.Context, token, patID, description string) (auth.PAT, error) {
defer func(begin time.Time) {
ms.counter.With("method", "update_pat_description").Add(1)
ms.latency.With("method", "update_pat_description").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.UpdatePATDescription(ctx, token, patID, description)
}
func (ms *metricsMiddleware) RetrievePAT(ctx context.Context, token, patID string) (auth.PAT, error) {
defer func(begin time.Time) {
ms.counter.With("method", "retrieve_pat").Add(1)
ms.latency.With("method", "retrieve_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RetrievePAT(ctx, token, patID)
}
func (ms *metricsMiddleware) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_pats").Add(1)
ms.latency.With("method", "list_pats").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ListPATS(ctx, token, pm)
}
func (ms *metricsMiddleware) DeletePAT(ctx context.Context, token, patID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "delete_pat").Add(1)
ms.latency.With("method", "delete_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.DeletePAT(ctx, token, patID)
}
func (ms *metricsMiddleware) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (auth.PAT, error) {
defer func(begin time.Time) {
ms.counter.With("method", "reset_pat_secret").Add(1)
ms.latency.With("method", "reset_pat_secret").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ResetPATSecret(ctx, token, patID, duration)
}
func (ms *metricsMiddleware) RevokePATSecret(ctx context.Context, token, patID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "revoke_pat_secret").Add(1)
ms.latency.With("method", "revoke_pat_secret").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RevokePATSecret(ctx, token, patID)
}
func (ms *metricsMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
defer func(begin time.Time) {
ms.counter.With("method", "add_pat_scope_entry").Add(1)
ms.latency.With("method", "add_pat_scope_entry").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (ms *metricsMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
defer func(begin time.Time) {
ms.counter.With("method", "remove_pat_scope_entry").Add(1)
ms.latency.With("method", "remove_pat_scope_entry").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (ms *metricsMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "clear_pat_all_scope_entry").Add(1)
ms.latency.With("method", "clear_pat_all_scope_entry").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ClearPATAllScopeEntry(ctx, token, patID)
}
func (ms *metricsMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) {
defer func(begin time.Time) {
ms.counter.With("method", "identify_pat").Add(1)
ms.latency.With("method", "identify_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.IdentifyPAT(ctx, paToken)
}
func (ms *metricsMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
defer func(begin time.Time) {
ms.counter.With("method", "authorize_pat").Add(1)
ms.latency.With("method", "authorize_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (ms *metricsMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
defer func(begin time.Time) {
ms.counter.With("method", "check_pat").Add(1)
ms.latency.With("method", "check_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
+6
View File
@@ -0,0 +1,6 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package bolt contains PAT repository implementations using
// bolt as the underlying database.
package bolt
+21
View File
@@ -0,0 +1,21 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package bolt contains PAT repository implementations using
// bolt as the underlying database.
package bolt
import (
"github.com/absmach/supermq/pkg/errors"
bolt "go.etcd.io/bbolt"
)
var errInit = errors.New("failed to initialize BoltDB")
func Init(tx *bolt.Tx, bucket string) error {
_, err := tx.CreateBucketIfNotExists([]byte(bucket))
if err != nil {
return errors.Wrap(errInit, err)
}
return nil
}
+812
View File
@@ -0,0 +1,812 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package bolt
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"strings"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
bolt "go.etcd.io/bbolt"
)
const (
idKey = "id"
userKey = "user"
nameKey = "name"
descriptionKey = "description"
secretKey = "secret_key"
scopeKey = "scope"
issuedAtKey = "issued_at"
expiresAtKey = "expires_at"
updatedAtKey = "updated_at"
lastUsedAtKey = "last_used_at"
revokedKey = "revoked"
revokedAtKey = "revoked_at"
platformEntitiesKey = "platform_entities"
patKey = "pat"
keySeparator = ":"
anyID = "*"
)
var (
activateValue = []byte{0x00}
revokedValue = []byte{0x01}
entityValue = []byte{0x02}
anyIDValue = []byte{0x03}
selectedIDsValue = []byte{0x04}
)
type patRepo struct {
db *bolt.DB
bucketName string
}
// NewPATSRepository instantiates a bolt
// implementation of PAT repository.
func NewPATSRepository(db *bolt.DB, bucketName string) auth.PATSRepository {
return &patRepo{
db: db,
bucketName: bucketName,
}
}
func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error {
idxKey := []byte(pat.User + keySeparator + patKey + keySeparator + pat.ID)
kv, err := patToKeyValue(pat)
if err != nil {
return err
}
return pr.db.Update(func(tx *bolt.Tx) error {
rootBucket, err := pr.retrieveRootBucket(tx)
if err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
b, err := pr.createUserBucket(rootBucket, pat.User)
if err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
for key, value := range kv {
fullKey := []byte(pat.ID + keySeparator + key)
if err := b.Put(fullKey, value); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
}
if err := rootBucket.Put(idxKey, []byte(pat.ID)); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
return nil
})
}
func (pr *patRepo) Retrieve(ctx context.Context, userID, patID string) (auth.PAT, error) {
prefix := []byte(patID + keySeparator)
kv := map[string][]byte{}
if err := pr.db.View(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity)
if err != nil {
return err
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
kv[string(k)] = v
}
return nil
}); err != nil {
return auth.PAT{}, err
}
return keyValueToPAT(kv)
}
func (pr *patRepo) RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error) {
revoked := true
expired := false
keySecret := patID + keySeparator + secretKey
keyRevoked := patID + keySeparator + revokedKey
keyExpiresAt := patID + keySeparator + expiresAtKey
var secretHash string
if err := pr.db.View(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity)
if err != nil {
return err
}
secretHash = string(b.Get([]byte(keySecret)))
revoked = bytesToBoolean(b.Get([]byte(keyRevoked)))
expiresAt := bytesToTime(b.Get([]byte(keyExpiresAt)))
expired = time.Now().After(expiresAt)
return nil
}); err != nil {
return "", true, true, err
}
return secretHash, revoked, expired, nil
}
func (pr *patRepo) UpdateName(ctx context.Context, userID, patID, name string) (auth.PAT, error) {
return pr.updatePATField(ctx, userID, patID, nameKey, []byte(name))
}
func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, description string) (auth.PAT, error) {
return pr.updatePATField(ctx, userID, patID, descriptionKey, []byte(description))
}
func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (auth.PAT, error) {
prefix := []byte(patID + keySeparator)
kv := map[string][]byte{}
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity)
if err != nil {
return err
}
if err := b.Put([]byte(patID+keySeparator+secretKey), []byte(tokenHash)); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+expiresAtKey), timeToBytes(expiryAt)); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
kv[string(k)] = v
}
return nil
}); err != nil {
return auth.PAT{}, err
}
return keyValueToPAT(kv)
}
func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) {
prefix := []byte(userID + keySeparator + patKey + keySeparator)
patIDs := []string{}
if err := pr.db.View(func(tx *bolt.Tx) error {
b, err := pr.retrieveRootBucket(tx)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
if v != nil {
patIDs = append(patIDs, string(v))
}
}
return nil
}); err != nil {
return auth.PATSPage{}, err
}
total := len(patIDs)
var pats []auth.PAT
patsPage := auth.PATSPage{
Total: uint64(total),
Limit: pm.Limit,
Offset: pm.Offset,
PATS: pats,
}
if int(pm.Offset) >= total {
return patsPage, nil
}
aLimit := pm.Limit
if rLimit := total - int(pm.Offset); int(pm.Limit) > rLimit {
aLimit = uint64(rLimit)
}
for i := pm.Offset; i < pm.Offset+aLimit; i++ {
if int(i) < total {
pat, err := pr.Retrieve(ctx, userID, patIDs[i])
if err != nil {
return patsPage, err
}
patsPage.PATS = append(patsPage.PATS, pat)
}
}
return patsPage, nil
}
func (pr *patRepo) Revoke(ctx context.Context, userID, patID string) error {
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity)
if err != nil {
return err
}
if err := b.Put([]byte(patID+keySeparator+revokedKey), revokedValue); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+revokedAtKey), timeToBytes(time.Now())); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (pr *patRepo) Reactivate(ctx context.Context, userID, patID string) error {
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity)
if err != nil {
return err
}
if err := b.Put([]byte(patID+keySeparator+revokedKey), activateValue); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+revokedAtKey), []byte{}); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (pr *patRepo) Remove(ctx context.Context, userID, patID string) error {
prefix := []byte(patID + keySeparator)
idxKey := []byte(userID + keySeparator + patKey + keySeparator + patID)
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity)
if err != nil {
return err
}
c := b.Cursor()
for k, _ := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = c.Next() {
if err := b.Delete(k); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
}
rb, err := pr.retrieveRootBucket(tx)
if err != nil {
return err
}
if err := rb.Delete(idxKey); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (pr *patRepo) AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
prefix := []byte(patID + keySeparator + scopeKey)
rKV := make(map[string][]byte)
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrCreateEntity)
if err != nil {
return err
}
kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if err != nil {
return err
}
for key, value := range kv {
fullKey := []byte(patID + keySeparator + key)
if err := b.Put(fullKey, value); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
rKV[string(k)] = v
}
return nil
}); err != nil {
return auth.Scope{}, err
}
return parseKeyValueToScope(rKV)
}
func (pr *patRepo) RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
if len(entityIDs) == 0 {
return auth.Scope{}, repoerr.ErrMalformedEntity
}
prefix := []byte(patID + keySeparator + scopeKey)
rKV := make(map[string][]byte)
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity)
if err != nil {
return err
}
kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if err != nil {
return err
}
for key := range kv {
fullKey := []byte(patID + keySeparator + key)
if err := b.Delete(fullKey); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
rKV[string(k)] = v
}
return nil
}); err != nil {
return auth.Scope{}, err
}
return parseKeyValueToScope(rKV)
}
func (pr *patRepo) CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
return pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
srootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
rootKey := patID + keySeparator + srootKey
if value := b.Get([]byte(rootKey)); bytes.Equal(value, anyIDValue) {
return nil
}
for _, entity := range entityIDs {
value := b.Get([]byte(rootKey + keySeparator + entity))
if !bytes.Equal(value, entityValue) {
return repoerr.ErrNotFound
}
}
return nil
})
}
func (pr *patRepo) RemoveAllScopeEntry(ctx context.Context, userID, patID string) error {
return nil
}
func (pr *patRepo) updatePATField(_ context.Context, userID, patID, key string, value []byte) (auth.PAT, error) {
prefix := []byte(patID + keySeparator)
kv := map[string][]byte{}
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity)
if err != nil {
return err
}
if err := b.Put([]byte(patID+keySeparator+key), value); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
kv[string(k)] = v
}
return nil
}); err != nil {
return auth.PAT{}, err
}
return keyValueToPAT(kv)
}
func (pr *patRepo) createUserBucket(rootBucket *bolt.Bucket, userID string) (*bolt.Bucket, error) {
userBucket, err := rootBucket.CreateBucketIfNotExists([]byte(userID))
if err != nil {
return nil, errors.Wrap(repoerr.ErrCreateEntity, fmt.Errorf("failed to retrieve or create bucket for user %s : %w", userID, err))
}
return userBucket, nil
}
func (pr *patRepo) retrieveUserBucket(tx *bolt.Tx, userID, patID string, wrap error) (*bolt.Bucket, error) {
rootBucket, err := pr.retrieveRootBucket(tx)
if err != nil {
return nil, errors.Wrap(wrap, err)
}
vPatID := rootBucket.Get([]byte(userID + keySeparator + patKey + keySeparator + patID))
if vPatID == nil {
return nil, repoerr.ErrNotFound
}
userBucket := rootBucket.Bucket([]byte(userID))
if userBucket == nil {
return nil, errors.Wrap(wrap, fmt.Errorf("user %s not found", userID))
}
return userBucket, nil
}
func (pr *patRepo) retrieveRootBucket(tx *bolt.Tx) (*bolt.Bucket, error) {
rootBucket := tx.Bucket([]byte(pr.bucketName))
if rootBucket == nil {
return nil, fmt.Errorf("bucket %s not found", pr.bucketName)
}
return rootBucket, nil
}
func patToKeyValue(pat auth.PAT) (map[string][]byte, error) {
kv := map[string][]byte{
idKey: []byte(pat.ID),
userKey: []byte(pat.User),
nameKey: []byte(pat.Name),
descriptionKey: []byte(pat.Description),
secretKey: []byte(pat.Secret),
issuedAtKey: timeToBytes(pat.IssuedAt),
expiresAtKey: timeToBytes(pat.ExpiresAt),
updatedAtKey: timeToBytes(pat.UpdatedAt),
lastUsedAtKey: timeToBytes(pat.LastUsedAt),
revokedKey: booleanToBytes(pat.Revoked),
revokedAtKey: timeToBytes(pat.RevokedAt),
}
scopeKV, err := scopeToKeyValue(pat.Scope)
if err != nil {
return nil, err
}
for k, v := range scopeKV {
kv[k] = v
}
return kv, nil
}
func scopeToKeyValue(scope auth.Scope) (map[string][]byte, error) {
kv := map[string][]byte{}
for opType, scopeValue := range scope.Users {
tempKV, err := scopeEntryToKeyValue(auth.PlatformUsersScope, "", auth.DomainNullScope, opType, scopeValue.Values()...)
if err != nil {
return nil, err
}
for k, v := range tempKV {
kv[k] = v
}
}
for opType, scopeValue := range scope.Dashboard {
tempKV, err := scopeEntryToKeyValue(auth.PlatformDashBoardScope, "", auth.DomainNullScope, opType, scopeValue.Values()...)
if err != nil {
return nil, err
}
for k, v := range tempKV {
kv[k] = v
}
}
for opType, scopeValue := range scope.Messaging {
tempKV, err := scopeEntryToKeyValue(auth.PlatformMesagingScope, "", auth.DomainNullScope, opType, scopeValue.Values()...)
if err != nil {
return nil, err
}
for k, v := range tempKV {
kv[k] = v
}
}
for domainID, domainScope := range scope.Domains {
for opType, scopeValue := range domainScope.DomainManagement {
tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, auth.DomainManagementScope, opType, scopeValue.Values()...)
if err != nil {
return nil, errors.Wrap(repoerr.ErrCreateEntity, err)
}
for k, v := range tempKV {
kv[k] = v
}
}
for entityType, scope := range domainScope.Entities {
for opType, scopeValue := range scope {
tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, entityType, opType, scopeValue.Values()...)
if err != nil {
return nil, errors.Wrap(repoerr.ErrCreateEntity, err)
}
for k, v := range tempKV {
kv[k] = v
}
}
}
}
return kv, nil
}
func scopeEntryToKeyValue(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (map[string][]byte, error) {
if len(entityIDs) == 0 {
return nil, repoerr.ErrMalformedEntity
}
rootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
if err != nil {
return nil, err
}
if len(entityIDs) == 1 && entityIDs[0] == anyID {
return map[string][]byte{rootKey: anyIDValue}, nil
}
kv := map[string][]byte{rootKey: selectedIDsValue}
for _, entryID := range entityIDs {
if entryID == anyID {
return nil, repoerr.ErrMalformedEntity
}
kv[rootKey+keySeparator+entryID] = entityValue
}
return kv, nil
}
func scopeRootKey(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType) (string, error) {
op, err := operation.ValidString()
if err != nil {
return "", errors.Wrap(repoerr.ErrMalformedEntity, err)
}
var rootKey strings.Builder
rootKey.WriteString(scopeKey)
rootKey.WriteString(keySeparator)
rootKey.WriteString(platformEntityType.String())
rootKey.WriteString(keySeparator)
switch platformEntityType {
case auth.PlatformUsersScope:
rootKey.WriteString(op)
case auth.PlatformDashBoardScope:
rootKey.WriteString(op)
case auth.PlatformMesagingScope:
rootKey.WriteString(op)
case auth.PlatformDomainsScope:
if optionalDomainID == "" {
return "", fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String())
}
odet, err := optionalDomainEntityType.ValidString()
if err != nil {
return "", errors.Wrap(repoerr.ErrMalformedEntity, err)
}
rootKey.WriteString(optionalDomainID)
rootKey.WriteString(keySeparator)
rootKey.WriteString(odet)
rootKey.WriteString(keySeparator)
rootKey.WriteString(op)
default:
return "", errors.Wrap(repoerr.ErrMalformedEntity, fmt.Errorf("invalid platform entity type %s", platformEntityType.String()))
}
return rootKey.String(), nil
}
func keyValueToBasicPAT(kv map[string][]byte) auth.PAT {
var pat auth.PAT
for k, v := range kv {
switch {
case strings.HasSuffix(k, keySeparator+idKey):
pat.ID = string(v)
case strings.HasSuffix(k, keySeparator+userKey):
pat.User = string(v)
case strings.HasSuffix(k, keySeparator+nameKey):
pat.Name = string(v)
case strings.HasSuffix(k, keySeparator+descriptionKey):
pat.Description = string(v)
case strings.HasSuffix(k, keySeparator+issuedAtKey):
pat.IssuedAt = bytesToTime(v)
case strings.HasSuffix(k, keySeparator+expiresAtKey):
pat.ExpiresAt = bytesToTime(v)
case strings.HasSuffix(k, keySeparator+updatedAtKey):
pat.UpdatedAt = bytesToTime(v)
case strings.HasSuffix(k, keySeparator+lastUsedAtKey):
pat.LastUsedAt = bytesToTime(v)
case strings.HasSuffix(k, keySeparator+revokedKey):
pat.Revoked = bytesToBoolean(v)
case strings.HasSuffix(k, keySeparator+revokedAtKey):
pat.RevokedAt = bytesToTime(v)
}
}
return pat
}
func keyValueToPAT(kv map[string][]byte) (auth.PAT, error) {
pat := keyValueToBasicPAT(kv)
scope, err := parseKeyValueToScope(kv)
if err != nil {
return auth.PAT{}, err
}
pat.Scope = scope
return pat, nil
}
func parseKeyValueToScope(kv map[string][]byte) (auth.Scope, error) {
scope := auth.Scope{
Domains: make(map[string]auth.DomainScope),
}
for key, value := range kv {
if strings.Index(key, keySeparator+scopeKey+keySeparator) > 0 {
keyParts := strings.Split(key, keySeparator)
platformEntityType, err := auth.ParsePlatformEntityType(keyParts[2])
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
switch platformEntityType {
case auth.PlatformUsersScope:
scope.Users, err = parseOperation(platformEntityType, scope.Users, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
case auth.PlatformDashBoardScope:
scope.Dashboard, err = parseOperation(platformEntityType, scope.Dashboard, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
case auth.PlatformMesagingScope:
scope.Messaging, err = parseOperation(platformEntityType, scope.Messaging, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
case auth.PlatformDomainsScope:
if len(keyParts) < 6 {
return auth.Scope{}, fmt.Errorf("invalid scope key format: %s", key)
}
domainID := keyParts[3]
if scope.Domains == nil {
scope.Domains = make(map[string]auth.DomainScope)
}
if _, ok := scope.Domains[domainID]; !ok {
scope.Domains[domainID] = auth.DomainScope{}
}
domainScope := scope.Domains[domainID]
entityType := keyParts[4]
switch entityType {
case auth.DomainManagementScope.String():
domainScope.DomainManagement, err = parseOperation(platformEntityType, domainScope.DomainManagement, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
default:
etype, err := auth.ParseDomainEntityType(entityType)
if err != nil {
return auth.Scope{}, fmt.Errorf("key %s invalid entity type %s : %w", key, entityType, err)
}
if domainScope.Entities == nil {
domainScope.Entities = make(map[auth.DomainEntityType]auth.OperationScope)
}
if _, ok := domainScope.Entities[etype]; !ok {
domainScope.Entities[etype] = auth.OperationScope{}
}
entityOperationScope := domainScope.Entities[etype]
entityOperationScope, err = parseOperation(platformEntityType, entityOperationScope, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
domainScope.Entities[etype] = entityOperationScope
}
scope.Domains[domainID] = domainScope
default:
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()))
}
}
}
return scope, nil
}
func parseOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) (auth.OperationScope, error) {
if opScope == nil {
opScope = make(map[auth.OperationType]auth.ScopeValue)
}
if err := validateOperation(platformEntityType, opScope, key, keyParts, value); err != nil {
return auth.OperationScope{}, err
}
switch string(value) {
case string(entityValue):
opType, err := auth.ParseOperationType(keyParts[len(keyParts)-2])
if err != nil {
return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
entityID := keyParts[len(keyParts)-1]
if _, oValueExists := opScope[opType]; !oValueExists {
opScope[opType] = &auth.SelectedIDs{}
}
oValue := opScope[opType]
if err := oValue.AddValues(entityID); err != nil {
return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity value %v : %w", key, entityID, err)
}
opScope[opType] = oValue
case string(anyIDValue):
opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1])
if err != nil {
return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
if oValue, oValueExists := opScope[opType]; oValueExists && oValue != nil {
if _, ok := oValue.(*auth.AnyIDs); !ok {
return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity anyIDs scope value : key already initialized with different type", key)
}
}
opScope[opType] = &auth.AnyIDs{}
case string(selectedIDsValue):
opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1])
if err != nil {
return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
oValue, oValueExists := opScope[opType]
if oValueExists && oValue != nil {
if _, ok := oValue.(*auth.SelectedIDs); !ok {
return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity selectedIDs scope value : key already initialized with different type", key)
}
}
if !oValueExists {
opScope[opType] = &auth.SelectedIDs{}
}
default:
return auth.OperationScope{}, fmt.Errorf("key %s have invalid value %v", key, value)
}
return opScope, nil
}
func validateOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) error {
expectedKeyPartsLength := 0
switch string(value) {
case string(entityValue):
switch platformEntityType {
case auth.PlatformDomainsScope:
expectedKeyPartsLength = 7
case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope:
expectedKeyPartsLength = 5
default:
return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String())
}
case string(selectedIDsValue), string(anyIDValue):
switch platformEntityType {
case auth.PlatformDomainsScope:
expectedKeyPartsLength = 6
case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope:
expectedKeyPartsLength = 4
default:
return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String())
}
default:
return fmt.Errorf("key %s have invalid value %v", key, value)
}
if len(keyParts) != expectedKeyPartsLength {
return fmt.Errorf("invalid scope key format: %s", key)
}
return nil
}
func timeToBytes(t time.Time) []byte {
timeBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timeBytes, uint64(t.Unix()))
return timeBytes
}
func bytesToTime(b []byte) time.Time {
timeAtSeconds := binary.BigEndian.Uint64(b)
return time.Unix(int64(timeAtSeconds), 0)
}
func booleanToBytes(b bool) []byte {
if b {
return []byte{1}
}
return []byte{0}
}
func bytesToBoolean(b []byte) bool {
if len(b) > 1 || b[0] != activateValue[0] {
return true
}
return false
}
+17
View File
@@ -0,0 +1,17 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package auth
// Hasher specifies an API for generating hashes of an arbitrary textual
// content.
//
//go:generate mockery --name Hasher --output=./mocks --filename hasher.go --quiet --note "Copyright (c) Abstract Machines"
type Hasher interface {
// Hash generates the hashed string from plain-text.
Hash(string) (string, error)
// Compare compares plain-text version to the hashed one. An error should
// indicate failed comparison.
Compare(string, string) error
}
+6
View File
@@ -0,0 +1,6 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package hasher contains the domain concept definitions needed to
// support Magistrala users password hasher sub-service functionality.
package hasher
+86
View File
@@ -0,0 +1,86 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package hasher
import (
"encoding/base64"
"fmt"
"math/rand"
"strings"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
"golang.org/x/crypto/scrypt"
)
var (
errHashToken = errors.New("failed to generate hash for token")
errHashCompare = errors.New("failed to generate hash for given compare string")
errToken = errors.New("given token and hash are not same")
errSalt = errors.New("failed to generate salt")
errInvalidHashStore = errors.New("invalid stored hash format")
errDecode = errors.New("failed to decode")
)
var _ auth.Hasher = (*bcryptHasher)(nil)
type bcryptHasher struct{}
// New instantiates a bcrypt-based hasher implementation.
func New() auth.Hasher {
return &bcryptHasher{}
}
func (bh *bcryptHasher) Hash(token string) (string, error) {
salt, err := generateSalt(25)
if err != nil {
return "", err
}
// N is kept 16384 to make faster and added large salt, since PAT will be access by automation scripts in high frequency.
hash, err := scrypt.Key([]byte(token), salt, 16384, 8, 1, 32)
if err != nil {
return "", errors.Wrap(errHashToken, err)
}
return fmt.Sprintf("%s.%s", base64.StdEncoding.EncodeToString(hash), base64.StdEncoding.EncodeToString(salt)), nil
}
func (bh *bcryptHasher) Compare(plain, hashed string) error {
parts := strings.Split(hashed, ".")
if len(parts) != 2 {
return errInvalidHashStore
}
actHash, err := base64.StdEncoding.DecodeString(parts[0])
if err != nil {
return errors.Wrap(errDecode, err)
}
salt, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return errors.Wrap(errDecode, err)
}
derivedHash, err := scrypt.Key([]byte(plain), salt, 16384, 8, 1, 32)
if err != nil {
return errors.Wrap(errHashCompare, err)
}
if string(derivedHash) == string(actHash) {
return nil
}
return errToken
}
func generateSalt(length int) ([]byte, error) {
rand.New(rand.NewSource(time.Now().UnixNano()))
salt := make([]byte, length)
_, err := rand.Read(salt)
if err != nil {
return nil, errors.Wrap(errSalt, err)
}
return salt, nil
}
+4
View File
@@ -30,6 +30,8 @@ const (
RecoveryKey
// APIKey enables the one to act on behalf of the user.
APIKey
// PersonalAccessToken represents token generated by user for automation.
PersonalAccessToken
// InvitationKey is a key for inviting new users.
InvitationKey
)
@@ -44,6 +46,8 @@ func (kt KeyType) String() string {
return "recovery"
case APIKey:
return "API"
case PersonalAccessToken:
return "pat"
default:
return "unknown"
}
+72
View File
@@ -0,0 +1,72 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import mock "github.com/stretchr/testify/mock"
// Hasher is an autogenerated mock type for the Hasher type
type Hasher struct {
mock.Mock
}
// Compare provides a mock function with given fields: _a0, _a1
func (_m *Hasher) Compare(_a0 string, _a1 string) error {
ret := _m.Called(_a0, _a1)
if len(ret) == 0 {
panic("no return value specified for Compare")
}
var r0 error
if rf, ok := ret.Get(0).(func(string, string) error); ok {
r0 = rf(_a0, _a1)
} else {
r0 = ret.Error(0)
}
return r0
}
// Hash provides a mock function with given fields: _a0
func (_m *Hasher) Hash(_a0 string) (string, error) {
ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for Hash")
}
var r0 string
var r1 error
if rf, ok := ret.Get(0).(func(string) (string, error)); ok {
return rf(_a0)
}
if rf, ok := ret.Get(0).(func(string) string); ok {
r0 = rf(_a0)
} else {
r0 = ret.Get(0).(string)
}
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewHasher creates a new instance of Hasher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewHasher(t interface {
mock.TestingT
Cleanup(func())
}) *Hasher {
mock := &Hasher{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+404
View File
@@ -0,0 +1,404 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
auth "github.com/absmach/supermq/auth"
mock "github.com/stretchr/testify/mock"
time "time"
)
// PATS is an autogenerated mock type for the PATS type
type PATS struct {
mock.Mock
}
// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATS) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for AddPATScopeEntry")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Get(0).(auth.Scope)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATS) AuthorizePAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for AuthorizePAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATS) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for CheckPAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID
func (_m *PATS) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for ClearPATAllScopeEntry")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope
func (_m *PATS) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) {
ret := _m.Called(ctx, token, name, description, duration, scope)
if len(ret) == 0 {
panic("no return value specified for CreatePAT")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok {
return rf(ctx, token, name, description, duration, scope)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok {
r0 = rf(ctx, token, name, description, duration, scope)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok {
r1 = rf(ctx, token, name, description, duration, scope)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// DeletePAT provides a mock function with given fields: ctx, token, patID
func (_m *PATS) DeletePAT(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for DeletePAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// IdentifyPAT provides a mock function with given fields: ctx, paToken
func (_m *PATS) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) {
ret := _m.Called(ctx, paToken)
if len(ret) == 0 {
panic("no return value specified for IdentifyPAT")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (auth.PAT, error)); ok {
return rf(ctx, paToken)
}
if rf, ok := ret.Get(0).(func(context.Context, string) auth.PAT); ok {
r0 = rf(ctx, paToken)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, paToken)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListPATS provides a mock function with given fields: ctx, token, pm
func (_m *PATS) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) {
ret := _m.Called(ctx, token, pm)
if len(ret) == 0 {
panic("no return value specified for ListPATS")
}
var r0 auth.PATSPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) (auth.PATSPage, error)); ok {
return rf(ctx, token, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) auth.PATSPage); ok {
r0 = rf(ctx, token, pm)
} else {
r0 = ret.Get(0).(auth.PATSPage)
}
if rf, ok := ret.Get(1).(func(context.Context, string, auth.PATSPageMeta) error); ok {
r1 = rf(ctx, token, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATS) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for RemovePATScopeEntry")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Get(0).(auth.Scope)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration
func (_m *PATS) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) {
ret := _m.Called(ctx, token, patID, duration)
if len(ret) == 0 {
panic("no return value specified for ResetPATSecret")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) (auth.PAT, error)); ok {
return rf(ctx, token, patID, duration)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) auth.PAT); ok {
r0 = rf(ctx, token, patID, duration)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, time.Duration) error); ok {
r1 = rf(ctx, token, patID, duration)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrievePAT provides a mock function with given fields: ctx, userID, patID
func (_m *PATS) RetrievePAT(ctx context.Context, userID string, patID string) (auth.PAT, error) {
ret := _m.Called(ctx, userID, patID)
if len(ret) == 0 {
panic("no return value specified for RetrievePAT")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (auth.PAT, error)); ok {
return rf(ctx, userID, patID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) auth.PAT); ok {
r0 = rf(ctx, userID, patID)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, userID, patID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RevokePATSecret provides a mock function with given fields: ctx, token, patID
func (_m *PATS) RevokePATSecret(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for RevokePATSecret")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdatePATDescription provides a mock function with given fields: ctx, token, patID, description
func (_m *PATS) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) {
ret := _m.Called(ctx, token, patID, description)
if len(ret) == 0 {
panic("no return value specified for UpdatePATDescription")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok {
return rf(ctx, token, patID, description)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok {
r0 = rf(ctx, token, patID, description)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, token, patID, description)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdatePATName provides a mock function with given fields: ctx, token, patID, name
func (_m *PATS) UpdatePATName(ctx context.Context, token string, patID string, name string) (auth.PAT, error) {
ret := _m.Called(ctx, token, patID, name)
if len(ret) == 0 {
panic("no return value specified for UpdatePATName")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok {
return rf(ctx, token, patID, name)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok {
r0 = rf(ctx, token, patID, name)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, token, patID, name)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewPATS creates a new instance of PATS. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewPATS(t interface {
mock.TestingT
Cleanup(func())
}) *PATS {
mock := &PATS{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+401
View File
@@ -0,0 +1,401 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
auth "github.com/absmach/supermq/auth"
mock "github.com/stretchr/testify/mock"
time "time"
)
// PATSRepository is an autogenerated mock type for the PATSRepository type
type PATSRepository struct {
mock.Mock
}
// AddScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATSRepository) AddScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for AddScopeEntry")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Get(0).(auth.Scope)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CheckScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATSRepository) CheckScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for CheckScopeEntry")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// Reactivate provides a mock function with given fields: ctx, userID, patID
func (_m *PATSRepository) Reactivate(ctx context.Context, userID string, patID string) error {
ret := _m.Called(ctx, userID, patID)
if len(ret) == 0 {
panic("no return value specified for Reactivate")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, userID, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Remove provides a mock function with given fields: ctx, userID, patID
func (_m *PATSRepository) Remove(ctx context.Context, userID string, patID string) error {
ret := _m.Called(ctx, userID, patID)
if len(ret) == 0 {
panic("no return value specified for Remove")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, userID, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemoveAllScopeEntry provides a mock function with given fields: ctx, userID, patID
func (_m *PATSRepository) RemoveAllScopeEntry(ctx context.Context, userID string, patID string) error {
ret := _m.Called(ctx, userID, patID)
if len(ret) == 0 {
panic("no return value specified for RemoveAllScopeEntry")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, userID, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemoveScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATSRepository) RemoveScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for RemoveScopeEntry")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Get(0).(auth.Scope)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Retrieve provides a mock function with given fields: ctx, userID, patID
func (_m *PATSRepository) Retrieve(ctx context.Context, userID string, patID string) (auth.PAT, error) {
ret := _m.Called(ctx, userID, patID)
if len(ret) == 0 {
panic("no return value specified for Retrieve")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (auth.PAT, error)); ok {
return rf(ctx, userID, patID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) auth.PAT); ok {
r0 = rf(ctx, userID, patID)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, userID, patID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveAll provides a mock function with given fields: ctx, userID, pm
func (_m *PATSRepository) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) {
ret := _m.Called(ctx, userID, pm)
if len(ret) == 0 {
panic("no return value specified for RetrieveAll")
}
var r0 auth.PATSPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) (auth.PATSPage, error)); ok {
return rf(ctx, userID, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) auth.PATSPage); ok {
r0 = rf(ctx, userID, pm)
} else {
r0 = ret.Get(0).(auth.PATSPage)
}
if rf, ok := ret.Get(1).(func(context.Context, string, auth.PATSPageMeta) error); ok {
r1 = rf(ctx, userID, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveSecretAndRevokeStatus provides a mock function with given fields: ctx, userID, patID
func (_m *PATSRepository) RetrieveSecretAndRevokeStatus(ctx context.Context, userID string, patID string) (string, bool, bool, error) {
ret := _m.Called(ctx, userID, patID)
if len(ret) == 0 {
panic("no return value specified for RetrieveSecretAndRevokeStatus")
}
var r0 string
var r1 bool
var r2 bool
var r3 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (string, bool, bool, error)); ok {
return rf(ctx, userID, patID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) string); ok {
r0 = rf(ctx, userID, patID)
} else {
r0 = ret.Get(0).(string)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) bool); ok {
r1 = rf(ctx, userID, patID)
} else {
r1 = ret.Get(1).(bool)
}
if rf, ok := ret.Get(2).(func(context.Context, string, string) bool); ok {
r2 = rf(ctx, userID, patID)
} else {
r2 = ret.Get(2).(bool)
}
if rf, ok := ret.Get(3).(func(context.Context, string, string) error); ok {
r3 = rf(ctx, userID, patID)
} else {
r3 = ret.Error(3)
}
return r0, r1, r2, r3
}
// Revoke provides a mock function with given fields: ctx, userID, patID
func (_m *PATSRepository) Revoke(ctx context.Context, userID string, patID string) error {
ret := _m.Called(ctx, userID, patID)
if len(ret) == 0 {
panic("no return value specified for Revoke")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, userID, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Save provides a mock function with given fields: ctx, pat
func (_m *PATSRepository) Save(ctx context.Context, pat auth.PAT) error {
ret := _m.Called(ctx, pat)
if len(ret) == 0 {
panic("no return value specified for Save")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, auth.PAT) error); ok {
r0 = rf(ctx, pat)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateDescription provides a mock function with given fields: ctx, userID, patID, description
func (_m *PATSRepository) UpdateDescription(ctx context.Context, userID string, patID string, description string) (auth.PAT, error) {
ret := _m.Called(ctx, userID, patID, description)
if len(ret) == 0 {
panic("no return value specified for UpdateDescription")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok {
return rf(ctx, userID, patID, description)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok {
r0 = rf(ctx, userID, patID, description)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, userID, patID, description)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateName provides a mock function with given fields: ctx, userID, patID, name
func (_m *PATSRepository) UpdateName(ctx context.Context, userID string, patID string, name string) (auth.PAT, error) {
ret := _m.Called(ctx, userID, patID, name)
if len(ret) == 0 {
panic("no return value specified for UpdateName")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok {
return rf(ctx, userID, patID, name)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok {
r0 = rf(ctx, userID, patID, name)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, userID, patID, name)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateTokenHash provides a mock function with given fields: ctx, userID, patID, tokenHash, expiryAt
func (_m *PATSRepository) UpdateTokenHash(ctx context.Context, userID string, patID string, tokenHash string, expiryAt time.Time) (auth.PAT, error) {
ret := _m.Called(ctx, userID, patID, tokenHash, expiryAt)
if len(ret) == 0 {
panic("no return value specified for UpdateTokenHash")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Time) (auth.PAT, error)); ok {
return rf(ctx, userID, patID, tokenHash, expiryAt)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Time) auth.PAT); ok {
r0 = rf(ctx, userID, patID, tokenHash, expiryAt)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Time) error); ok {
r1 = rf(ctx, userID, patID, tokenHash, expiryAt)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewPATSRepository creates a new instance of PATSRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewPATSRepository(t interface {
mock.TestingT
Cleanup(func())
}) *PATSRepository {
mock := &PATSRepository{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+372
View File
@@ -12,6 +12,8 @@ import (
mock "github.com/stretchr/testify/mock"
policies "github.com/absmach/supermq/pkg/policies"
time "time"
)
// Service is an autogenerated mock type for the Service type
@@ -19,6 +21,41 @@ type Service struct {
mock.Mock
}
// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *Service) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for AddPATScopeEntry")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Get(0).(auth.Scope)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Authorize provides a mock function with given fields: ctx, pr
func (_m *Service) Authorize(ctx context.Context, pr policies.Policy) error {
ret := _m.Called(ctx, pr)
@@ -37,6 +74,120 @@ func (_m *Service) Authorize(ctx context.Context, pr policies.Policy) error {
return r0
}
// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *Service) AuthorizePAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for AuthorizePAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *Service) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for CheckPAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID
func (_m *Service) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for ClearPATAllScopeEntry")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope
func (_m *Service) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) {
ret := _m.Called(ctx, token, name, description, duration, scope)
if len(ret) == 0 {
panic("no return value specified for CreatePAT")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok {
return rf(ctx, token, name, description, duration, scope)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok {
r0 = rf(ctx, token, name, description, duration, scope)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok {
r1 = rf(ctx, token, name, description, duration, scope)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// DeletePAT provides a mock function with given fields: ctx, token, patID
func (_m *Service) DeletePAT(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for DeletePAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Identify provides a mock function with given fields: ctx, token
func (_m *Service) Identify(ctx context.Context, token string) (auth.Key, error) {
ret := _m.Called(ctx, token)
@@ -65,6 +216,34 @@ func (_m *Service) Identify(ctx context.Context, token string) (auth.Key, error)
return r0, r1
}
// IdentifyPAT provides a mock function with given fields: ctx, paToken
func (_m *Service) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) {
ret := _m.Called(ctx, paToken)
if len(ret) == 0 {
panic("no return value specified for IdentifyPAT")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (auth.PAT, error)); ok {
return rf(ctx, paToken)
}
if rf, ok := ret.Get(0).(func(context.Context, string) auth.PAT); ok {
r0 = rf(ctx, paToken)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, paToken)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Issue provides a mock function with given fields: ctx, token, key
func (_m *Service) Issue(ctx context.Context, token string, key auth.Key) (auth.Token, error) {
ret := _m.Called(ctx, token, key)
@@ -93,6 +272,97 @@ func (_m *Service) Issue(ctx context.Context, token string, key auth.Key) (auth.
return r0, r1
}
// ListPATS provides a mock function with given fields: ctx, token, pm
func (_m *Service) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) {
ret := _m.Called(ctx, token, pm)
if len(ret) == 0 {
panic("no return value specified for ListPATS")
}
var r0 auth.PATSPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) (auth.PATSPage, error)); ok {
return rf(ctx, token, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) auth.PATSPage); ok {
r0 = rf(ctx, token, pm)
} else {
r0 = ret.Get(0).(auth.PATSPage)
}
if rf, ok := ret.Get(1).(func(context.Context, string, auth.PATSPageMeta) error); ok {
r1 = rf(ctx, token, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *Service) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for RemovePATScopeEntry")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Get(0).(auth.Scope)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration
func (_m *Service) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) {
ret := _m.Called(ctx, token, patID, duration)
if len(ret) == 0 {
panic("no return value specified for ResetPATSecret")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) (auth.PAT, error)); ok {
return rf(ctx, token, patID, duration)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) auth.PAT); ok {
r0 = rf(ctx, token, patID, duration)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, time.Duration) error); ok {
r1 = rf(ctx, token, patID, duration)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveKey provides a mock function with given fields: ctx, token, id
func (_m *Service) RetrieveKey(ctx context.Context, token string, id string) (auth.Key, error) {
ret := _m.Called(ctx, token, id)
@@ -121,6 +391,34 @@ func (_m *Service) RetrieveKey(ctx context.Context, token string, id string) (au
return r0, r1
}
// RetrievePAT provides a mock function with given fields: ctx, userID, patID
func (_m *Service) RetrievePAT(ctx context.Context, userID string, patID string) (auth.PAT, error) {
ret := _m.Called(ctx, userID, patID)
if len(ret) == 0 {
panic("no return value specified for RetrievePAT")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (auth.PAT, error)); ok {
return rf(ctx, userID, patID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) auth.PAT); ok {
r0 = rf(ctx, userID, patID)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, userID, patID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Revoke provides a mock function with given fields: ctx, token, id
func (_m *Service) Revoke(ctx context.Context, token string, id string) error {
ret := _m.Called(ctx, token, id)
@@ -139,6 +437,80 @@ func (_m *Service) Revoke(ctx context.Context, token string, id string) error {
return r0
}
// RevokePATSecret provides a mock function with given fields: ctx, token, patID
func (_m *Service) RevokePATSecret(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for RevokePATSecret")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdatePATDescription provides a mock function with given fields: ctx, token, patID, description
func (_m *Service) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) {
ret := _m.Called(ctx, token, patID, description)
if len(ret) == 0 {
panic("no return value specified for UpdatePATDescription")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok {
return rf(ctx, token, patID, description)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok {
r0 = rf(ctx, token, patID, description)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, token, patID, description)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdatePATName provides a mock function with given fields: ctx, token, patID, name
func (_m *Service) UpdatePATName(ctx context.Context, token string, patID string, name string) (auth.PAT, error) {
ret := _m.Called(ctx, token, patID, name)
if len(ret) == 0 {
panic("no return value specified for UpdatePATName")
}
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok {
return rf(ctx, token, patID, name)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok {
r0 = rf(ctx, token, patID, name)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, token, patID, name)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
+804
View File
@@ -0,0 +1,804 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package auth
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/absmach/supermq/pkg/errors"
)
var errAddEntityToAnyIDs = errors.New("could not add entity id to any ID scope value")
// Define OperationType.
type OperationType uint32
const (
CreateOp OperationType = iota
ReadOp
ListOp
UpdateOp
DeleteOp
ShareOp
UnshareOp
PublishOp
SubscribeOp
)
const (
createOpStr = "create"
readOpStr = "read"
listOpStr = "list"
updateOpStr = "update"
deleteOpStr = "delete"
shareOpStr = "share"
UnshareOpStr = "unshare"
PublishOpStr = "publish"
SubscribeOpStr = "subscribe"
)
func (ot OperationType) String() string {
switch ot {
case CreateOp:
return createOpStr
case ReadOp:
return readOpStr
case ListOp:
return listOpStr
case UpdateOp:
return updateOpStr
case DeleteOp:
return deleteOpStr
case ShareOp:
return shareOpStr
case UnshareOp:
return UnshareOpStr
case PublishOp:
return PublishOpStr
case SubscribeOp:
return SubscribeOpStr
default:
return fmt.Sprintf("unknown operation type %d", ot)
}
}
func (ot OperationType) ValidString() (string, error) {
str := ot.String()
if str == fmt.Sprintf("unknown operation type %d", ot) {
return "", errors.New(str)
}
return str, nil
}
func ParseOperationType(ot string) (OperationType, error) {
switch ot {
case createOpStr:
return CreateOp, nil
case readOpStr:
return ReadOp, nil
case listOpStr:
return ListOp, nil
case updateOpStr:
return UpdateOp, nil
case deleteOpStr:
return DeleteOp, nil
case shareOpStr:
return ShareOp, nil
case UnshareOpStr:
return UnshareOp, nil
case PublishOpStr:
return PublishOp, nil
case SubscribeOpStr:
return SubscribeOp, nil
default:
return 0, fmt.Errorf("unknown operation type %s", ot)
}
}
func (ot OperationType) MarshalJSON() ([]byte, error) {
return []byte(ot.String()), nil
}
func (ot OperationType) MarshalText() (text []byte, err error) {
return []byte(ot.String()), nil
}
func (ot *OperationType) UnmarshalText(data []byte) (err error) {
*ot, err = ParseOperationType(string(data))
return err
}
// Define DomainEntityType.
type DomainEntityType uint32
const (
DomainManagementScope DomainEntityType = iota
DomainGroupsScope
DomainChannelsScope
DomainClientsScope
DomainNullScope
)
const (
domainManagementScopeStr = "domain_management"
domainGroupsScopeStr = "groups"
domainChannelsScopeStr = "channels"
domainClientsScopeStr = "clients"
)
func (det DomainEntityType) String() string {
switch det {
case DomainManagementScope:
return domainManagementScopeStr
case DomainGroupsScope:
return domainGroupsScopeStr
case DomainChannelsScope:
return domainChannelsScopeStr
case DomainClientsScope:
return domainClientsScopeStr
default:
return fmt.Sprintf("unknown domain entity type %d", det)
}
}
func (det DomainEntityType) ValidString() (string, error) {
str := det.String()
if str == fmt.Sprintf("unknown operation type %d", det) {
return "", errors.New(str)
}
return str, nil
}
func ParseDomainEntityType(det string) (DomainEntityType, error) {
switch det {
case domainManagementScopeStr:
return DomainManagementScope, nil
case domainGroupsScopeStr:
return DomainGroupsScope, nil
case domainChannelsScopeStr:
return DomainChannelsScope, nil
case domainClientsScopeStr:
return DomainClientsScope, nil
default:
return 0, fmt.Errorf("unknown domain entity type %s", det)
}
}
func (det DomainEntityType) MarshalJSON() ([]byte, error) {
return []byte(det.String()), nil
}
func (det DomainEntityType) MarshalText() ([]byte, error) {
return []byte(det.String()), nil
}
func (det *DomainEntityType) UnmarshalText(data []byte) (err error) {
*det, err = ParseDomainEntityType(string(data))
return err
}
// Define DomainEntityType.
type PlatformEntityType uint32
const (
PlatformUsersScope PlatformEntityType = iota
PlatformDomainsScope
PlatformDashBoardScope
PlatformMesagingScope
)
const (
platformUsersScopeStr = "users"
platformDomainsScopeStr = "domains"
PlatformDashBoardScopeStr = "dashboard"
PlatformMesagingScopeStr = "messaging"
)
func (pet PlatformEntityType) String() string {
switch pet {
case PlatformUsersScope:
return platformUsersScopeStr
case PlatformDomainsScope:
return platformDomainsScopeStr
case PlatformDashBoardScope:
return PlatformDashBoardScopeStr
case PlatformMesagingScope:
return PlatformMesagingScopeStr
default:
return fmt.Sprintf("unknown platform entity type %d", pet)
}
}
func (pet PlatformEntityType) ValidString() (string, error) {
str := pet.String()
if str == fmt.Sprintf("unknown platform entity type %d", pet) {
return "", errors.New(str)
}
return str, nil
}
func ParsePlatformEntityType(pet string) (PlatformEntityType, error) {
switch pet {
case platformUsersScopeStr:
return PlatformUsersScope, nil
case platformDomainsScopeStr:
return PlatformDomainsScope, nil
default:
return 0, fmt.Errorf("unknown platform entity type %s", pet)
}
}
func (pet PlatformEntityType) MarshalJSON() ([]byte, error) {
return []byte(pet.String()), nil
}
func (pet PlatformEntityType) MarshalText() (text []byte, err error) {
return []byte(pet.String()), nil
}
func (pet *PlatformEntityType) UnmarshalText(data []byte) (err error) {
*pet, err = ParsePlatformEntityType(string(data))
return err
}
// ScopeValue interface for Any entity ids or for sets of entity ids.
type ScopeValue interface {
Contains(id string) bool
Values() []string
AddValues(ids ...string) error
RemoveValues(ids ...string) error
}
// AnyIDs implements ScopeValue for any entity id value.
type AnyIDs struct{}
func (s AnyIDs) Contains(id string) bool { return true }
func (s AnyIDs) Values() []string { return []string{"*"} }
func (s *AnyIDs) AddValues(ids ...string) error { return errAddEntityToAnyIDs }
func (s *AnyIDs) RemoveValues(ids ...string) error { return errAddEntityToAnyIDs }
// SelectedIDs implements ScopeValue for sets of entity ids.
type SelectedIDs map[string]struct{}
func (s SelectedIDs) Contains(id string) bool { _, ok := s[id]; return ok }
func (s SelectedIDs) Values() []string {
values := []string{}
for value := range s {
values = append(values, value)
}
return values
}
func (s *SelectedIDs) AddValues(ids ...string) error {
if *s == nil {
*s = make(SelectedIDs)
}
for _, id := range ids {
(*s)[id] = struct{}{}
}
return nil
}
func (s *SelectedIDs) RemoveValues(ids ...string) error {
if *s == nil {
return nil
}
for _, id := range ids {
delete(*s, id)
}
return nil
}
// OperationScope contains map of OperationType with value of AnyIDs or SelectedIDs.
type OperationScope map[OperationType]ScopeValue
func (os *OperationScope) UnmarshalJSON(data []byte) error {
type tempOperationScope map[OperationType]json.RawMessage
var tempScope tempOperationScope
if err := json.Unmarshal(data, &tempScope); err != nil {
return err
}
// Initialize the Operations map
*os = OperationScope{}
for opType, rawMessage := range tempScope {
var stringValue string
var stringArrayValue []string
// Try to unmarshal as string
if err := json.Unmarshal(rawMessage, &stringValue); err == nil {
if err := os.Add(opType, stringValue); err != nil {
return err
}
continue
}
// Try to unmarshal as []string
if err := json.Unmarshal(rawMessage, &stringArrayValue); err == nil {
if err := os.Add(opType, stringArrayValue...); err != nil {
return err
}
continue
}
// If neither unmarshalling succeeded, return an error
return fmt.Errorf("invalid ScopeValue for OperationType %v", opType)
}
return nil
}
func (os OperationScope) MarshalJSON() ([]byte, error) {
tempOperationScope := make(map[OperationType]interface{})
for oType, scope := range os {
value := scope.Values()
if len(value) == 1 && value[0] == "*" {
tempOperationScope[oType] = "*"
continue
}
tempOperationScope[oType] = value
}
b, err := json.Marshal(tempOperationScope)
if err != nil {
return nil, err
}
return b, nil
}
func (os *OperationScope) Add(operation OperationType, entityIDs ...string) error {
var value ScopeValue
if os == nil {
os = &OperationScope{}
}
if len(entityIDs) == 0 {
return fmt.Errorf("entity ID is missing")
}
switch {
case len(entityIDs) == 1 && entityIDs[0] == "*":
value = &AnyIDs{}
default:
var sids SelectedIDs
for _, entityID := range entityIDs {
if entityID == "*" {
return fmt.Errorf("list contains wildcard")
}
if sids == nil {
sids = make(SelectedIDs)
}
sids[entityID] = struct{}{}
}
value = &sids
}
(*os)[operation] = value
return nil
}
func (os *OperationScope) Delete(operation OperationType, entityIDs ...string) error {
if os == nil {
return nil
}
opEntityIDs, exists := (*os)[operation]
if !exists {
return nil
}
if len(entityIDs) == 0 {
return fmt.Errorf("failed to delete operation %s: entity ID is missing", operation.String())
}
switch eIDs := opEntityIDs.(type) {
case *AnyIDs:
if !(len(entityIDs) == 1 && entityIDs[0] == "*") {
return fmt.Errorf("failed to delete operation %s: invalid list", operation.String())
}
delete((*os), operation)
return nil
case *SelectedIDs:
for _, entityID := range entityIDs {
if !eIDs.Contains(entityID) {
return fmt.Errorf("failed to delete operation %s: invalid entity ID in list", operation.String())
}
}
for _, entityID := range entityIDs {
delete(*eIDs, entityID)
if len(*eIDs) == 0 {
delete((*os), operation)
}
}
return nil
default:
return fmt.Errorf("failed to delete operation: invalid entity id type %d", operation)
}
}
func (os *OperationScope) Check(operation OperationType, entityIDs ...string) bool {
if os == nil {
return false
}
if scopeValue, ok := (*os)[operation]; ok {
if len(entityIDs) == 0 {
_, ok := scopeValue.(*AnyIDs)
return ok
}
for _, entityID := range entityIDs {
if !scopeValue.Contains(entityID) {
return false
}
}
return true
}
return false
}
type DomainScope struct {
DomainManagement OperationScope `json:"domain_management,omitempty"`
Entities map[DomainEntityType]OperationScope `json:"entities,omitempty"`
}
// Add entry in Domain scope.
func (ds *DomainScope) Add(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if ds == nil {
return fmt.Errorf("failed to add domain %s scope: domain_scope is nil and not initialized", domainEntityType)
}
if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope {
return fmt.Errorf("failed to add domain %d scope: invalid domain entity type", domainEntityType)
}
if domainEntityType == DomainManagementScope {
if err := ds.DomainManagement.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete domain management scope: %w", err)
}
}
if ds.Entities == nil {
ds.Entities = make(map[DomainEntityType]OperationScope)
}
opReg, ok := ds.Entities[domainEntityType]
if !ok {
opReg = OperationScope{}
}
if err := opReg.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add domain %s scope: %w ", domainEntityType.String(), err)
}
ds.Entities[domainEntityType] = opReg
return nil
}
// Delete entry in Domain scope.
func (ds *DomainScope) Delete(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if ds == nil {
return nil
}
if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope {
return fmt.Errorf("failed to delete domain %d scope: invalid domain entity type", domainEntityType)
}
if ds.Entities == nil {
return nil
}
if domainEntityType == DomainManagementScope {
if err := ds.DomainManagement.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete domain management scope: %w", err)
}
}
os, exists := ds.Entities[domainEntityType]
if !exists {
return nil
}
if err := os.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete domain %s scope: %w", domainEntityType.String(), err)
}
if len(os) == 0 {
delete(ds.Entities, domainEntityType)
}
return nil
}
// Check entry in Domain scope.
func (ds *DomainScope) Check(domainEntityType DomainEntityType, operation OperationType, ids ...string) bool {
if ds.Entities == nil {
return false
}
if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope {
return false
}
if domainEntityType == DomainManagementScope {
return ds.DomainManagement.Check(operation, ids...)
}
os, exists := ds.Entities[domainEntityType]
if !exists {
return false
}
return os.Check(operation, ids...)
}
// Example Scope as JSON
//
// {
// "users": {
// "create": ["*"],
// "read": ["*"],
// "list": ["*"],
// "update": ["*"],
// "delete": ["*"]
// },
// "domains": {
// "domain_1": {
// "entities": {
// "groups": {
// "create": ["*"] // this for all groups in domain
// },
// "channels": {
// // for particular channel in domain
// "delete": [
// "channel1",
// "channel2"
// ]
// },
// "things": {
// "update": ["*"] // this for all things in domain
// }
// }
// }
// }
// }
type Scope struct {
Users OperationScope `json:"users,omitempty"`
Domains map[string]DomainScope `json:"domains,omitempty"`
Dashboard OperationScope `json:"dashboard,omitempty"`
Messaging OperationScope `json:"messaging,omitempty"`
}
// Add entry in Domain scope.
func (s *Scope) Add(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if s == nil {
return fmt.Errorf("failed to add platform %s scope: scope is nil and not initialized", platformEntityType.String())
}
switch platformEntityType {
case PlatformUsersScope:
if err := s.Users.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformDashBoardScope:
if err := s.Dashboard.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformMesagingScope:
if err := s.Messaging.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformDomainsScope:
if optionalDomainID == "" {
return fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String())
}
if len(s.Domains) == 0 {
s.Domains = make(map[string]DomainScope)
}
ds, ok := s.Domains[optionalDomainID]
if !ok {
ds = DomainScope{}
}
if err := ds.Add(optionalDomainEntityType, operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err)
}
s.Domains[optionalDomainID] = ds
default:
return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType)
}
return nil
}
// Delete entry in Domain scope.
func (s *Scope) Delete(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if s == nil {
return nil
}
switch platformEntityType {
case PlatformUsersScope:
if err := s.Users.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformDashBoardScope:
if err := s.Dashboard.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformMesagingScope:
if err := s.Messaging.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformDomainsScope:
if optionalDomainID == "" {
return fmt.Errorf("failed to delete platform %s scope: invalid domain id", platformEntityType.String())
}
ds, ok := s.Domains[optionalDomainID]
if !ok {
return nil
}
if err := ds.Delete(optionalDomainEntityType, operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err)
}
default:
return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType)
}
return nil
}
// Check entry in Domain scope.
func (s *Scope) Check(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) bool {
if s == nil {
return false
}
switch platformEntityType {
case PlatformUsersScope:
return s.Users.Check(operation, entityIDs...)
case PlatformDashBoardScope:
return s.Dashboard.Check(operation, entityIDs...)
case PlatformMesagingScope:
return s.Messaging.Check(operation, entityIDs...)
case PlatformDomainsScope:
ds, ok := s.Domains[optionalDomainID]
if !ok {
return false
}
return ds.Check(optionalDomainEntityType, operation, entityIDs...)
default:
return false
}
}
func (s *Scope) String() string {
str, err := json.Marshal(s) // , "", " ")
if err != nil {
return fmt.Sprintf("failed to convert scope to string: json marshal error :%s", err.Error())
}
return string(str)
}
// PAT represents Personal Access Token.
type PAT struct {
ID string `json:"id,omitempty"`
User string `json:"user,omitempty"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Secret string `json:"secret,omitempty"`
Scope Scope `json:"scope,omitempty"`
IssuedAt time.Time `json:"issued_at,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
LastUsedAt time.Time `json:"last_used_at,omitempty"`
Revoked bool `json:"revoked,omitempty"`
RevokedAt time.Time `json:"revoked_at,omitempty"`
}
type PATSPageMeta struct {
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
}
type PATSPage struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
PATS []PAT `json:"pats"`
}
func (pat *PAT) String() string {
str, err := json.MarshalIndent(pat, "", " ")
if err != nil {
return fmt.Sprintf("failed to convert PAT to string: json marshal error :%s", err.Error())
}
return string(str)
}
// Expired verifies if the key is expired.
func (pat PAT) Expired() bool {
return pat.ExpiresAt.UTC().Before(time.Now().UTC())
}
// PATS specifies function which are required for Personal access Token implementation.
//go:generate mockery --name PATS --output=./mocks --filename pats.go --quiet --note "Copyright (c) Abstract Machines"
type PATS interface {
// Create function creates new PAT for given valid inputs.
CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error)
// UpdateName function updates the name for the given PAT ID.
UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error)
// UpdateDescription function updates the description for the given PAT ID.
UpdatePATDescription(ctx context.Context, token, patID, description string) (PAT, error)
// Retrieve function retrieves the PAT for given ID.
RetrievePAT(ctx context.Context, userID string, patID string) (PAT, error)
// List function lists all the PATs for the user.
ListPATS(ctx context.Context, token string, pm PATSPageMeta) (PATSPage, error)
// Delete function deletes the PAT for given ID.
DeletePAT(ctx context.Context, token, patID string) error
// ResetSecret function reset the secret and creates new secret for the given ID.
ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (PAT, error)
// RevokeSecret function revokes the secret for the given ID.
RevokePATSecret(ctx context.Context, token, patID string) error
// AddScope function adds a new scope entry.
AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error)
// RemoveScope function removes a scope entry.
RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error)
// ClearAllScope function removes all scope entry.
ClearPATAllScopeEntry(ctx context.Context, token, patID string) error
// IdentifyPAT function will valid the secret.
IdentifyPAT(ctx context.Context, paToken string) (PAT, error)
// AuthorizePAT function will valid the secret and check the given scope exists.
AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error
// CheckPAT function will check the given scope exists.
CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error
}
// PATSRepository specifies PATS persistence API.
//
//go:generate mockery --name PATSRepository --output=./mocks --filename patsrepo.go --quiet --note "Copyright (c) Abstract Machines"
type PATSRepository interface {
// Save persists the PAT
Save(ctx context.Context, pat PAT) (err error)
// Retrieve retrieves users PAT by its unique identifier.
Retrieve(ctx context.Context, userID, patID string) (pat PAT, err error)
// RetrieveSecretAndRevokeStatus retrieves secret and revoke status of PAT by its unique identifier.
RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error)
// UpdateName updates the name of a PAT.
UpdateName(ctx context.Context, userID, patID, name string) (PAT, error)
// UpdateDescription updates the description of a PAT.
UpdateDescription(ctx context.Context, userID, patID, description string) (PAT, error)
// UpdateTokenHash updates the token hash of a PAT.
UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (PAT, error)
// RetrieveAll retrieves all PATs belongs to userID.
RetrieveAll(ctx context.Context, userID string, pm PATSPageMeta) (pats PATSPage, err error)
// Revoke PAT with provided ID.
Revoke(ctx context.Context, userID, patID string) error
// Reactivate PAT with provided ID.
Reactivate(ctx context.Context, userID, patID string) error
// Remove removes Key with provided ID.
Remove(ctx context.Context, userID, patID string) error
AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error)
RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error)
CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error
RemoveAllScopeEntry(ctx context.Context, userID, patID string) error
}
+281 -3
View File
@@ -5,6 +5,8 @@ package auth
import (
"context"
"encoding/base64"
"math/rand"
"strings"
"time"
@@ -12,11 +14,15 @@ import (
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/google/uuid"
)
const (
recoveryDuration = 5 * time.Minute
defLimit = 100
recoveryDuration = 5 * time.Minute
defLimit = 100
randStr = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&&*|+-="
patPrefix = "pat"
patSecretSeparator = "_"
)
var (
@@ -29,6 +35,17 @@ var (
errRetrieve = errors.New("failed to retrieve key data")
errIdentify = errors.New("failed to validate token")
errPlatform = errors.New("invalid platform id")
errMalformedPAT = errors.New("malformed personal access token")
errFailedToParseUUID = errors.New("failed to parse string to UUID")
errInvalidLenFor2UUIDs = errors.New("invalid input length for 2 UUID, excepted 32 byte")
errRevokedPAT = errors.New("revoked pat")
errCreatePAT = errors.New("failed to create PAT")
errUpdatePAT = errors.New("failed to update PAT")
errRetrievePAT = errors.New("failed to retrieve PAT")
errDeletePAT = errors.New("failed to delete PAT")
errRevokePAT = errors.New("failed to revoke PAT")
errClearAllScope = errors.New("failed to clear all entry in scope")
)
// Authz represents a authorization service. It exposes
@@ -75,12 +92,15 @@ type Authn interface {
type Service interface {
Authn
Authz
PATS
}
var _ Service = (*service)(nil)
type service struct {
keys KeyRepository
pats PATSRepository
hasher Hasher
idProvider supermq.IDProvider
evaluator policies.Evaluator
policysvc policies.Service
@@ -91,10 +111,12 @@ type service struct {
}
// New instantiates the auth service implementation.
func New(keys KeyRepository, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service {
func New(keys KeyRepository, pats PATSRepository, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service {
return &service{
tokenizer: tokenizer,
keys: keys,
pats: pats,
hasher: hasher,
idProvider: idp,
evaluator: policyEvaluator,
policysvc: policyService,
@@ -434,3 +456,259 @@ func DecodeDomainUserID(domainUserID string) (string, string) {
return "", ""
}
}
func (svc service) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return PAT{}, err
}
id, err := svc.idProvider.ID()
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err)
}
secret, hash, err := svc.generateSecretAndHash(key.User, id)
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err)
}
now := time.Now()
pat := PAT{
ID: id,
User: key.User,
Name: name,
Description: description,
Secret: hash,
IssuedAt: now,
ExpiresAt: now.Add(duration),
Scope: scope,
}
if err := svc.pats.Save(ctx, pat); err != nil {
return PAT{}, errors.Wrap(errCreatePAT, err)
}
pat.Secret = secret
return pat, nil
}
func (svc service) UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return PAT{}, err
}
pat, err := svc.pats.UpdateName(ctx, key.User, patID, name)
if err != nil {
return PAT{}, errors.Wrap(errUpdatePAT, err)
}
return pat, nil
}
func (svc service) UpdatePATDescription(ctx context.Context, token, patID, description string) (PAT, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return PAT{}, err
}
pat, err := svc.pats.UpdateDescription(ctx, key.User, patID, description)
if err != nil {
return PAT{}, errors.Wrap(errUpdatePAT, err)
}
return pat, nil
}
func (svc service) RetrievePAT(ctx context.Context, token, patID string) (PAT, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return PAT{}, err
}
pat, err := svc.pats.Retrieve(ctx, key.User, patID)
if err != nil {
return PAT{}, errors.Wrap(errRetrievePAT, err)
}
return pat, nil
}
func (svc service) ListPATS(ctx context.Context, token string, pm PATSPageMeta) (PATSPage, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return PATSPage{}, err
}
patsPage, err := svc.pats.RetrieveAll(ctx, key.User, pm)
if err != nil {
return PATSPage{}, errors.Wrap(errRetrievePAT, err)
}
return patsPage, nil
}
func (svc service) DeletePAT(ctx context.Context, token, patID string) error {
key, err := svc.Identify(ctx, token)
if err != nil {
return err
}
if err := svc.pats.Remove(ctx, key.User, patID); err != nil {
return errors.Wrap(errDeletePAT, err)
}
return nil
}
func (svc service) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (PAT, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return PAT{}, err
}
// Generate new HashToken take place here
secret, hash, err := svc.generateSecretAndHash(key.User, patID)
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
pat, err := svc.pats.UpdateTokenHash(ctx, key.User, patID, hash, time.Now().Add(duration))
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
if err := svc.pats.Reactivate(ctx, key.User, patID); err != nil {
return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
pat.Secret = secret
pat.Revoked = false
pat.RevokedAt = time.Time{}
return pat, nil
}
func (svc service) RevokePATSecret(ctx context.Context, token, patID string) error {
key, err := svc.Identify(ctx, token)
if err != nil {
return err
}
if err := svc.pats.Revoke(ctx, key.User, patID); err != nil {
return errors.Wrap(errRevokePAT, err)
}
return nil
}
func (svc service) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return Scope{}, err
}
scope, err := svc.pats.AddScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if err != nil {
return Scope{}, errors.Wrap(errRevokePAT, err)
}
return scope, nil
}
func (svc service) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return Scope{}, err
}
scope, err := svc.pats.RemoveScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if err != nil {
return Scope{}, err
}
return scope, nil
}
func (svc service) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error {
key, err := svc.Identify(ctx, token)
if err != nil {
return err
}
if err := svc.pats.RemoveAllScopeEntry(ctx, key.User, patID); err != nil {
return errors.Wrap(errClearAllScope, err)
}
return nil
}
func (svc service) IdentifyPAT(ctx context.Context, secret string) (PAT, error) {
parts := strings.Split(secret, patSecretSeparator)
if len(parts) != 3 && parts[0] != patPrefix {
return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errMalformedPAT)
}
userID, patID, err := decode(parts[1])
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errMalformedPAT)
}
secretHash, revoked, expired, err := svc.pats.RetrieveSecretAndRevokeStatus(ctx, userID.String(), patID.String())
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if revoked {
return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedPAT)
}
if expired {
return PAT{}, errors.Wrap(svcerr.ErrAuthentication, ErrExpiry)
}
if err := svc.hasher.Compare(secret, secretHash); err != nil {
return PAT{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
return PAT{ID: patID.String(), User: userID.String()}, nil
}
func (svc service) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
res, err := svc.RetrievePAT(ctx, userID, patID)
if err != nil {
return err
}
if err := svc.pats.CheckScopeEntry(ctx, res.User, res.ID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil {
return errors.Wrap(svcerr.ErrAuthorization, err)
}
return nil
}
func (svc service) CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if err := svc.pats.CheckScopeEntry(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil {
return errors.Wrap(svcerr.ErrAuthorization, err)
}
return nil
}
func (svc service) generateSecretAndHash(userID, patID string) (string, string, error) {
uID, err := uuid.Parse(userID)
if err != nil {
return "", "", errors.Wrap(errFailedToParseUUID, err)
}
pID, err := uuid.Parse(patID)
if err != nil {
return "", "", errors.Wrap(errFailedToParseUUID, err)
}
secret := patPrefix + patSecretSeparator + encode(uID, pID) + patSecretSeparator + generateRandomString(100)
secretHash, err := svc.hasher.Hash(secret)
return secret, secretHash, err
}
func encode(userID, patID uuid.UUID) string {
c := append(userID[:], patID[:]...)
return base64.StdEncoding.EncodeToString(c)
}
func decode(encoded string) (uuid.UUID, uuid.UUID, error) {
data, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return uuid.Nil, uuid.Nil, err
}
if len(data) != 32 {
return uuid.Nil, uuid.Nil, errInvalidLenFor2UUIDs
}
var userID, patID uuid.UUID
copy(userID[:], data[:16])
copy(patID[:], data[16:])
return userID, patID, nil
}
func generateRandomString(n int) string {
letterRunes := []rune(randStr)
rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]rune, n)
for i := range b {
b[i] = letterRunes[rand.Intn(len(letterRunes))]
}
return string(b)
}
+5 -1
View File
@@ -49,12 +49,16 @@ var (
krepo *mocks.KeyRepository
pService *policymocks.Service
pEvaluator *policymocks.Evaluator
patsrepo *mocks.PATSRepository
hasher *mocks.Hasher
)
func newService() (auth.Service, string) {
krepo = new(mocks.KeyRepository)
pService = new(policymocks.Service)
pEvaluator = new(policymocks.Evaluator)
patsrepo = new(mocks.PATSRepository)
hasher = new(mocks.Hasher)
idProvider := uuid.NewMock()
t := jwt.New([]byte(secret))
@@ -68,7 +72,7 @@ func newService() (auth.Service, string) {
}
token, _ := t.Issue(key)
return auth.New(krepo, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token
return auth.New(krepo, patsrepo, hasher, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token
}
func TestIssue(t *testing.T) {
+139
View File
@@ -6,6 +6,7 @@ package tracing
import (
"context"
"fmt"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/policies"
@@ -74,3 +75,141 @@ func (tm *tracingMiddleware) Authorize(ctx context.Context, pr policies.Policy)
return tm.svc.Authorize(ctx, pr)
}
func (tm *tracingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) {
ctx, span := tm.tracer.Start(ctx, "create_pat", trace.WithAttributes(
attribute.String("name", name),
attribute.String("description", description),
attribute.String("duration", duration.String()),
attribute.String("scope", scope.String()),
))
defer span.End()
return tm.svc.CreatePAT(ctx, token, name, description, duration, scope)
}
func (tm *tracingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) {
ctx, span := tm.tracer.Start(ctx, "update_pat_name", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("name", name),
))
defer span.End()
return tm.svc.UpdatePATName(ctx, token, patID, name)
}
func (tm *tracingMiddleware) UpdatePATDescription(ctx context.Context, token, patID, description string) (auth.PAT, error) {
ctx, span := tm.tracer.Start(ctx, "update_pat_description", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("description", description),
))
defer span.End()
return tm.svc.UpdatePATDescription(ctx, token, patID, description)
}
func (tm *tracingMiddleware) RetrievePAT(ctx context.Context, token, patID string) (auth.PAT, error) {
ctx, span := tm.tracer.Start(ctx, "retrieve_pat", trace.WithAttributes(
attribute.String("pat_id", patID),
))
defer span.End()
return tm.svc.RetrievePAT(ctx, token, patID)
}
func (tm *tracingMiddleware) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) {
ctx, span := tm.tracer.Start(ctx, "list_pat", trace.WithAttributes(
attribute.Int64("limit", int64(pm.Limit)),
attribute.Int64("offset", int64(pm.Offset)),
))
defer span.End()
return tm.svc.ListPATS(ctx, token, pm)
}
func (tm *tracingMiddleware) DeletePAT(ctx context.Context, token, patID string) error {
ctx, span := tm.tracer.Start(ctx, "delete_pat", trace.WithAttributes(
attribute.String("pat_id", patID),
))
defer span.End()
return tm.svc.DeletePAT(ctx, token, patID)
}
func (tm *tracingMiddleware) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (auth.PAT, error) {
ctx, span := tm.tracer.Start(ctx, "reset_pat_secret", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("duration", duration.String()),
))
defer span.End()
return tm.svc.ResetPATSecret(ctx, token, patID, duration)
}
func (tm *tracingMiddleware) RevokePATSecret(ctx context.Context, token, patID string) error {
ctx, span := tm.tracer.Start(ctx, "revoke_pat_secret", trace.WithAttributes(
attribute.String("pat_id", patID),
))
defer span.End()
return tm.svc.RevokePATSecret(ctx, token, patID)
}
func (tm *tracingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
ctx, span := tm.tracer.Start(ctx, "add_pat_scope_entry", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("platform_entity", platformEntityType.String()),
attribute.String("optional_domain_id", optionalDomainID),
attribute.String("optional_domain_entity", optionalDomainEntityType.String()),
attribute.String("operation", operation.String()),
attribute.StringSlice("entities", entityIDs),
))
defer span.End()
return tm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (tm *tracingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
ctx, span := tm.tracer.Start(ctx, "remove_pat_scope_entry", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("platform_entity", platformEntityType.String()),
attribute.String("optional_domain_id", optionalDomainID),
attribute.String("optional_domain_entity", optionalDomainEntityType.String()),
attribute.String("operation", operation.String()),
attribute.StringSlice("entities", entityIDs),
))
defer span.End()
return tm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (tm *tracingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error {
ctx, span := tm.tracer.Start(ctx, "clear_pat_all_scope_entry", trace.WithAttributes(
attribute.String("pat_id", patID),
))
defer span.End()
return tm.svc.ClearPATAllScopeEntry(ctx, token, patID)
}
func (tm *tracingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) {
ctx, span := tm.tracer.Start(ctx, "identity_pat")
defer span.End()
return tm.svc.IdentifyPAT(ctx, paToken)
}
func (tm *tracingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
ctx, span := tm.tracer.Start(ctx, "authorize_pat", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("platform_entity", platformEntityType.String()),
attribute.String("optional_domain_id", optionalDomainID),
attribute.String("optional_domain_entity", optionalDomainEntityType.String()),
attribute.String("operation", operation.String()),
attribute.StringSlice("entities", entityIDs),
))
defer span.End()
return tm.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (tm *tracingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
ctx, span := tm.tracer.Start(ctx, "check_pat", trace.WithAttributes(
attribute.String("user_id", userID),
attribute.String("patID", patID),
attribute.String("platform_entity", platformEntityType.String()),
attribute.String("optional_domain_id", optionalDomainID),
attribute.String("optional_domain_entity", optionalDomainEntityType.String()),
attribute.String("operation", operation.String()),
attribute.StringSlice("entities", entityIDs),
))
defer span.End()
return tm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
+27 -5
View File
@@ -19,15 +19,17 @@ import (
authgrpcapi "github.com/absmach/supermq/auth/api/grpc/auth"
tokengrpcapi "github.com/absmach/supermq/auth/api/grpc/token"
httpapi "github.com/absmach/supermq/auth/api/http"
"github.com/absmach/supermq/auth/bolt"
"github.com/absmach/supermq/auth/hasher"
"github.com/absmach/supermq/auth/jwt"
apostgres "github.com/absmach/supermq/auth/postgres"
"github.com/absmach/supermq/auth/tracing"
boltclient "github.com/absmach/supermq/internal/clients/bolt"
grpcAuthV1 "github.com/absmach/supermq/internal/grpc/auth/v1"
grpcTokenV1 "github.com/absmach/supermq/internal/grpc/token/v1"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/pkg/jaeger"
"github.com/absmach/supermq/pkg/policies/spicedb"
"github.com/absmach/supermq/pkg/postgres"
pgclient "github.com/absmach/supermq/pkg/postgres"
"github.com/absmach/supermq/pkg/prometheus"
"github.com/absmach/supermq/pkg/server"
@@ -39,6 +41,7 @@ import (
"github.com/authzed/grpcutil"
"github.com/caarlos0/env/v11"
"github.com/jmoiron/sqlx"
"go.etcd.io/bbolt"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
@@ -51,6 +54,7 @@ const (
envPrefixHTTP = "SMQ_AUTH_HTTP_"
envPrefixGrpc = "SMQ_AUTH_GRPC_"
envPrefixDB = "SMQ_AUTH_DB_"
envPrefixPATDB = "SMQ_AUTH_PAT_DB_"
defDB = "auth"
defSvcHTTPPort = "8189"
defSvcGRPCPort = "8181"
@@ -131,7 +135,23 @@ func main() {
exitCode = 1
return
}
svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient)
boltDBConfig := boltclient.Config{}
if err := env.ParseWithOptions(&boltDBConfig, env.Options{Prefix: envPrefixPATDB}); err != nil {
logger.Error(fmt.Sprintf("failed to parse bolt db config : %s\n", err.Error()))
exitCode = 1
return
}
bClient, err := boltclient.Connect(boltDBConfig, bolt.Init)
if err != nil {
logger.Error(fmt.Sprintf("failed to connect to bolt db : %s\n", err.Error()))
exitCode = 1
return
}
defer bClient.Close()
svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient, bClient, boltDBConfig)
grpcServerConfig := server.Config{Port: defSvcGRPCPort}
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGrpc}); err != nil {
@@ -211,9 +231,11 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch
return nil
}
func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental) auth.Service {
database := postgres.NewDatabase(db, dbConfig, tracer)
func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, bClient *bbolt.DB, bConfig boltclient.Config) auth.Service {
database := pgclient.NewDatabase(db, dbConfig, tracer)
keysRepo := apostgres.New(database)
patsRepo := bolt.NewPATSRepository(bClient, bConfig.Bucket)
hasher := hasher.New()
idProvider := uuid.New()
pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger)
@@ -221,7 +243,7 @@ func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config,
t := jwt.New([]byte(cfg.SecretKey))
svc := auth.New(keysRepo, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
svc := auth.New(keysRepo, patsRepo, hasher, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
svc = api.LoggingMiddleware(svc, logger)
counter, latency := prometheus.MakeMetrics("auth", "api")
svc = api.MetricsMiddleware(svc, counter, latency)
+2
View File
@@ -17,6 +17,7 @@ volumes:
supermq-mqtt-broker-volume:
supermq-spicedb-db-volume:
supermq-auth-db-volume:
supermq-pat-db-volume:
supermq-domains-db-volume:
supermq-invitations-db-volume:
supermq-ui-db-volume:
@@ -136,6 +137,7 @@ services:
- supermq-base-net
volumes:
- ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE}
- supermq-pat-db-volume:/supermq-data
# Auth gRPC mTLS server certificates
- type: bind
source: ${SMQ_AUTH_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert}
+2 -1
View File
@@ -18,6 +18,7 @@ require (
github.com/go-chi/chi/v5 v5.1.0
github.com/go-kit/kit v0.13.0
github.com/gofrs/uuid/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/gookit/color v1.5.4
github.com/gorilla/websocket v1.5.3
github.com/hashicorp/vault/api v1.15.0
@@ -44,6 +45,7 @@ require (
github.com/spf13/viper v1.19.0
github.com/sqids/sqids-go v0.4.1
github.com/stretchr/testify v1.10.0
go.etcd.io/bbolt v1.3.11
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.57.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0
go.opentelemetry.io/otel v1.33.0
@@ -98,7 +100,6 @@ require (
github.com/goccy/go-json v0.10.3 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect
+2
View File
@@ -465,6 +465,8 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.etcd.io/bbolt v1.3.11 h1:yGEzV1wPz2yVCLsD8ZAiGHhHVlczyC9d1rP43/VCRJ0=
go.etcd.io/bbolt v1.3.11/go.mod h1:dksAq7YMXoljX0xu6VF5DMZGbhYYoLUalEiSySYAS4I=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.57.0 h1:qtFISDHKolvIxzSs0gIaiPUPR0Cucb0F2coHC7ZLdps=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.57.0/go.mod h1:Y+Pop1Q6hCOnETWTW4NROK/q1hv50hM7yDaUTjG8lp8=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU=
+8 -3
View File
@@ -7,6 +7,7 @@ import (
"context"
"net/http"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/apiutil"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/go-chi/chi/v5"
@@ -14,7 +15,9 @@ import (
type sessionKeyType string
const SessionKey = sessionKeyType("session")
const (
SessionKey = sessionKeyType("session")
)
func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
@@ -24,8 +27,10 @@ func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) fun
EncodeError(r.Context(), apiutil.ErrBearerToken, w)
return
}
var resp smqauthn.Session
var err error
resp, err := authn.Authenticate(r.Context(), token)
resp, err = authn.Authenticate(r.Context(), token)
if err != nil {
EncodeError(r.Context(), err, w)
return
@@ -38,7 +43,7 @@ func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) fun
return
}
resp.DomainID = domain
resp.DomainUserID = domain + "_" + resp.UserID
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
}
ctx := context.WithValue(r.Context(), SessionKey, resp)
+4 -1
View File
@@ -134,7 +134,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
case errors.Contains(err, svcerr.ErrAuthentication),
errors.Contains(err, apiutil.ErrBearerToken),
errors.Contains(err, svcerr.ErrLogin):
errors.Contains(err, svcerr.ErrLogin),
errors.Contains(err, apiutil.ErrUnsupportedTokenType):
err = unwrap(err)
w.WriteHeader(http.StatusUnauthorized)
case errors.Contains(err, svcerr.ErrMalformedEntity),
@@ -184,6 +185,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
errors.Contains(err, apiutil.ErrLenSearchQuery),
errors.Contains(err, apiutil.ErrMissingDomainID),
errors.Contains(err, certs.ErrFailedReadFromPKI),
errors.Contains(err, apiutil.ErrMissingUserID),
errors.Contains(err, apiutil.ErrMissingPATID),
errors.Contains(err, apiutil.ErrMissingUsername),
errors.Contains(err, apiutil.ErrMissingFirstName),
errors.Contains(err, apiutil.ErrMissingLastName),
+83
View File
@@ -0,0 +1,83 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package bolt
import (
"io/fs"
"strconv"
"time"
"github.com/absmach/supermq/pkg/errors"
"github.com/caarlos0/env/v11"
bolt "go.etcd.io/bbolt"
)
var (
errConfig = errors.New("failed to load BoltDB configuration")
errConnect = errors.New("failed to connect to BoltDB database")
errInit = errors.New("failed to initialize to BoltDB database")
)
type FileMode fs.FileMode
func (fm *FileMode) UnmarshalText(text []byte) error {
temp, err := strconv.ParseUint(string(text), 8, 32)
if err != nil {
return err
}
*fm = FileMode(temp)
return nil
}
// Config contains BoltDB specific parameters.
type Config struct {
FileDirPath string `env:"FILE_DIR_PATH" envDefault:"./supermq-data"`
FileName string `env:"FILE_NAME" envDefault:"supermq-pat.db"`
FileMode FileMode `env:"FILE_MODE" envDefault:"0600"`
Bucket string `env:"BUCKET" envDefault:"supermq"`
Timeout time.Duration `env:"TIMEOUT" envDefault:"0"`
}
// Setup load configuration from environment and creates new BoltDB.
func Setup(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) {
return SetupDB(envPrefix, initFn)
}
// SetupDB load configuration from environment,.
func SetupDB(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) {
cfg := Config{}
if err := env.ParseWithOptions(&cfg, env.Options{Prefix: envPrefix}); err != nil {
return nil, errors.Wrap(errConfig, err)
}
bdb, err := Connect(cfg, initFn)
if err != nil {
return nil, err
}
return bdb, nil
}
// Connect establishes connection to the BoltDB.
func Connect(cfg Config, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) {
filePath := cfg.FileDirPath + "/" + cfg.FileName
db, err := bolt.Open(filePath, fs.FileMode(cfg.FileMode), nil)
if err != nil {
return nil, errors.Wrap(errConnect, err)
}
if initFn != nil {
if err := Init(db, cfg, initFn); err != nil {
return nil, err
}
}
return db, nil
}
func Init(db *bolt.DB, cfg Config, initFn func(*bolt.Tx, string) error) error {
if err := db.Update(func(tx *bolt.Tx) error {
return initFn(tx, cfg.Bucket)
}); err != nil {
return errors.Wrap(errInit, err)
}
return nil
}
+9
View File
@@ -0,0 +1,9 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package BoltDB contains the domain concept definitions needed to support
// Supermq BoltDB database functionality.
//
// It provides the abstraction of the BoltDB database service, which is used
// to configure, setup and connect to the BoltDB database.
package bolt
+151 -28
View File
@@ -238,6 +238,99 @@ func (x *AuthZReq) GetObjectType() string {
return ""
}
type AuthZPatReq struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // User id
PatId string `protobuf:"bytes,2,opt,name=pat_id,json=patId,proto3" json:"pat_id,omitempty"` // Pat id
PlatformEntityType uint32 `protobuf:"varint,3,opt,name=platform_entity_type,json=platformEntityType,proto3" json:"platform_entity_type,omitempty"` // Platform entity type
OptionalDomainId string `protobuf:"bytes,4,opt,name=optional_domain_id,json=optionalDomainId,proto3" json:"optional_domain_id,omitempty"` // Optional domain id
OptionalDomainEntityType uint32 `protobuf:"varint,5,opt,name=optional_domain_entity_type,json=optionalDomainEntityType,proto3" json:"optional_domain_entity_type,omitempty"` // Optional domain entity type
Operation uint32 `protobuf:"varint,6,opt,name=operation,proto3" json:"operation,omitempty"` // Operation
EntityIds []string `protobuf:"bytes,7,rep,name=entity_ids,json=entityIds,proto3" json:"entity_ids,omitempty"` // EntityIDs
}
func (x *AuthZPatReq) Reset() {
*x = AuthZPatReq{}
mi := &file_auth_v1_auth_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthZPatReq) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthZPatReq) ProtoMessage() {}
func (x *AuthZPatReq) ProtoReflect() protoreflect.Message {
mi := &file_auth_v1_auth_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use AuthZPatReq.ProtoReflect.Descriptor instead.
func (*AuthZPatReq) Descriptor() ([]byte, []int) {
return file_auth_v1_auth_proto_rawDescGZIP(), []int{3}
}
func (x *AuthZPatReq) GetUserId() string {
if x != nil {
return x.UserId
}
return ""
}
func (x *AuthZPatReq) GetPatId() string {
if x != nil {
return x.PatId
}
return ""
}
func (x *AuthZPatReq) GetPlatformEntityType() uint32 {
if x != nil {
return x.PlatformEntityType
}
return 0
}
func (x *AuthZPatReq) GetOptionalDomainId() string {
if x != nil {
return x.OptionalDomainId
}
return ""
}
func (x *AuthZPatReq) GetOptionalDomainEntityType() uint32 {
if x != nil {
return x.OptionalDomainEntityType
}
return 0
}
func (x *AuthZPatReq) GetOperation() uint32 {
if x != nil {
return x.Operation
}
return 0
}
func (x *AuthZPatReq) GetEntityIds() []string {
if x != nil {
return x.EntityIds
}
return nil
}
type AuthZRes struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -249,7 +342,7 @@ type AuthZRes struct {
func (x *AuthZRes) Reset() {
*x = AuthZRes{}
mi := &file_auth_v1_auth_proto_msgTypes[3]
mi := &file_auth_v1_auth_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -261,7 +354,7 @@ func (x *AuthZRes) String() string {
func (*AuthZRes) ProtoMessage() {}
func (x *AuthZRes) ProtoReflect() protoreflect.Message {
mi := &file_auth_v1_auth_proto_msgTypes[3]
mi := &file_auth_v1_auth_proto_msgTypes[4]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -274,7 +367,7 @@ func (x *AuthZRes) ProtoReflect() protoreflect.Message {
// Deprecated: Use AuthZRes.ProtoReflect.Descriptor instead.
func (*AuthZRes) Descriptor() ([]byte, []int) {
return file_auth_v1_auth_proto_rawDescGZIP(), []int{3}
return file_auth_v1_auth_proto_rawDescGZIP(), []int{4}
}
func (x *AuthZRes) GetAuthorized() bool {
@@ -321,22 +414,47 @@ var file_auth_v1_auth_proto_rawDesc = []byte{
0x06, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6f,
0x62, 0x6a, 0x65, 0x63, 0x74, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x5f,
0x74, 0x79, 0x70, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6f, 0x62, 0x6a, 0x65,
0x63, 0x74, 0x54, 0x79, 0x70, 0x65, 0x22, 0x3a, 0x0a, 0x08, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52,
0x65, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64,
0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a,
0x65, 0x64, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02,
0x69, 0x64, 0x32, 0x7a, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63,
0x65, 0x12, 0x33, 0x0a, 0x09, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x12, 0x11,
0x63, 0x74, 0x54, 0x79, 0x70, 0x65, 0x22, 0x99, 0x02, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x5a,
0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12,
0x15, 0x0a, 0x06, 0x70, 0x61, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
0x05, 0x70, 0x61, 0x74, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x14, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f,
0x72, 0x6d, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03,
0x20, 0x01, 0x28, 0x0d, 0x52, 0x12, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x45, 0x6e,
0x74, 0x69, 0x74, 0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x2c, 0x0a, 0x12, 0x6f, 0x70, 0x74, 0x69,
0x6f, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x04,
0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x49, 0x64, 0x12, 0x3d, 0x0a, 0x1b, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e,
0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79,
0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x18, 0x6f, 0x70, 0x74,
0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x45, 0x6e, 0x74, 0x69, 0x74,
0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69,
0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74,
0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x69, 0x64,
0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x49,
0x64, 0x73, 0x22, 0x3a, 0x0a, 0x08, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x12, 0x1e,
0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01,
0x28, 0x08, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x12, 0x0e,
0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x32, 0xf0,
0x01, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33,
0x0a, 0x09, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x12, 0x11, 0x2e, 0x61, 0x75,
0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x71, 0x1a, 0x11,
0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65,
0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65,
0x50, 0x41, 0x54, 0x12, 0x14, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75,
0x74, 0x68, 0x5a, 0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68,
0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x36,
0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x11,
0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65,
0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68,
0x5a, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x36, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e,
0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31,
0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68,
0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x73, 0x22, 0x00, 0x42, 0x32,
0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x62, 0x73,
0x6d, 0x61, 0x63, 0x68, 0x2f, 0x73, 0x75, 0x70, 0x65, 0x72, 0x6d, 0x71, 0x2f, 0x69, 0x6e, 0x74,
0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x2f,
0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x4e, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e,
0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x50, 0x41, 0x54, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68,
0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61,
0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x73, 0x22,
0x00, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x61, 0x62, 0x73, 0x6d, 0x61, 0x63, 0x68, 0x2f, 0x73, 0x75, 0x70, 0x65, 0x72, 0x6d, 0x71, 0x2f,
0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75,
0x74, 0x68, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@@ -351,20 +469,25 @@ func file_auth_v1_auth_proto_rawDescGZIP() []byte {
return file_auth_v1_auth_proto_rawDescData
}
var file_auth_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_auth_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_auth_v1_auth_proto_goTypes = []any{
(*AuthNReq)(nil), // 0: auth.v1.AuthNReq
(*AuthNRes)(nil), // 1: auth.v1.AuthNRes
(*AuthZReq)(nil), // 2: auth.v1.AuthZReq
(*AuthZRes)(nil), // 3: auth.v1.AuthZRes
(*AuthNReq)(nil), // 0: auth.v1.AuthNReq
(*AuthNRes)(nil), // 1: auth.v1.AuthNRes
(*AuthZReq)(nil), // 2: auth.v1.AuthZReq
(*AuthZPatReq)(nil), // 3: auth.v1.AuthZPatReq
(*AuthZRes)(nil), // 4: auth.v1.AuthZRes
}
var file_auth_v1_auth_proto_depIdxs = []int32{
2, // 0: auth.v1.AuthService.Authorize:input_type -> auth.v1.AuthZReq
0, // 1: auth.v1.AuthService.Authenticate:input_type -> auth.v1.AuthNReq
3, // 2: auth.v1.AuthService.Authorize:output_type -> auth.v1.AuthZRes
1, // 3: auth.v1.AuthService.Authenticate:output_type -> auth.v1.AuthNRes
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
3, // 1: auth.v1.AuthService.AuthorizePAT:input_type -> auth.v1.AuthZPatReq
0, // 2: auth.v1.AuthService.Authenticate:input_type -> auth.v1.AuthNReq
0, // 3: auth.v1.AuthService.AuthenticatePAT:input_type -> auth.v1.AuthNReq
4, // 4: auth.v1.AuthService.Authorize:output_type -> auth.v1.AuthZRes
4, // 5: auth.v1.AuthService.AuthorizePAT:output_type -> auth.v1.AuthZRes
1, // 6: auth.v1.AuthService.Authenticate:output_type -> auth.v1.AuthNRes
1, // 7: auth.v1.AuthService.AuthenticatePAT:output_type -> auth.v1.AuthNRes
4, // [4:8] is the sub-list for method output_type
0, // [0:4] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
@@ -381,7 +504,7 @@ func file_auth_v1_auth_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_auth_v1_auth_proto_rawDesc,
NumEnums: 0,
NumMessages: 4,
NumMessages: 5,
NumExtensions: 0,
NumServices: 1,
},
+78 -2
View File
@@ -22,8 +22,10 @@ import (
const _ = grpc.SupportPackageIsVersion9
const (
AuthService_Authorize_FullMethodName = "/auth.v1.AuthService/Authorize"
AuthService_Authenticate_FullMethodName = "/auth.v1.AuthService/Authenticate"
AuthService_Authorize_FullMethodName = "/auth.v1.AuthService/Authorize"
AuthService_AuthorizePAT_FullMethodName = "/auth.v1.AuthService/AuthorizePAT"
AuthService_Authenticate_FullMethodName = "/auth.v1.AuthService/Authenticate"
AuthService_AuthenticatePAT_FullMethodName = "/auth.v1.AuthService/AuthenticatePAT"
)
// AuthServiceClient is the client API for AuthService service.
@@ -34,7 +36,9 @@ const (
// and authorization functionalities for SuperMQ services.
type AuthServiceClient interface {
Authorize(ctx context.Context, in *AuthZReq, opts ...grpc.CallOption) (*AuthZRes, error)
AuthorizePAT(ctx context.Context, in *AuthZPatReq, opts ...grpc.CallOption) (*AuthZRes, error)
Authenticate(ctx context.Context, in *AuthNReq, opts ...grpc.CallOption) (*AuthNRes, error)
AuthenticatePAT(ctx context.Context, in *AuthNReq, opts ...grpc.CallOption) (*AuthNRes, error)
}
type authServiceClient struct {
@@ -55,6 +59,16 @@ func (c *authServiceClient) Authorize(ctx context.Context, in *AuthZReq, opts ..
return out, nil
}
func (c *authServiceClient) AuthorizePAT(ctx context.Context, in *AuthZPatReq, opts ...grpc.CallOption) (*AuthZRes, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AuthZRes)
err := c.cc.Invoke(ctx, AuthService_AuthorizePAT_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *authServiceClient) Authenticate(ctx context.Context, in *AuthNReq, opts ...grpc.CallOption) (*AuthNRes, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AuthNRes)
@@ -65,6 +79,16 @@ func (c *authServiceClient) Authenticate(ctx context.Context, in *AuthNReq, opts
return out, nil
}
func (c *authServiceClient) AuthenticatePAT(ctx context.Context, in *AuthNReq, opts ...grpc.CallOption) (*AuthNRes, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AuthNRes)
err := c.cc.Invoke(ctx, AuthService_AuthenticatePAT_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// AuthServiceServer is the server API for AuthService service.
// All implementations must embed UnimplementedAuthServiceServer
// for forward compatibility.
@@ -73,7 +97,9 @@ func (c *authServiceClient) Authenticate(ctx context.Context, in *AuthNReq, opts
// and authorization functionalities for SuperMQ services.
type AuthServiceServer interface {
Authorize(context.Context, *AuthZReq) (*AuthZRes, error)
AuthorizePAT(context.Context, *AuthZPatReq) (*AuthZRes, error)
Authenticate(context.Context, *AuthNReq) (*AuthNRes, error)
AuthenticatePAT(context.Context, *AuthNReq) (*AuthNRes, error)
mustEmbedUnimplementedAuthServiceServer()
}
@@ -87,9 +113,15 @@ type UnimplementedAuthServiceServer struct{}
func (UnimplementedAuthServiceServer) Authorize(context.Context, *AuthZReq) (*AuthZRes, error) {
return nil, status.Errorf(codes.Unimplemented, "method Authorize not implemented")
}
func (UnimplementedAuthServiceServer) AuthorizePAT(context.Context, *AuthZPatReq) (*AuthZRes, error) {
return nil, status.Errorf(codes.Unimplemented, "method AuthorizePAT not implemented")
}
func (UnimplementedAuthServiceServer) Authenticate(context.Context, *AuthNReq) (*AuthNRes, error) {
return nil, status.Errorf(codes.Unimplemented, "method Authenticate not implemented")
}
func (UnimplementedAuthServiceServer) AuthenticatePAT(context.Context, *AuthNReq) (*AuthNRes, error) {
return nil, status.Errorf(codes.Unimplemented, "method AuthenticatePAT not implemented")
}
func (UnimplementedAuthServiceServer) mustEmbedUnimplementedAuthServiceServer() {}
func (UnimplementedAuthServiceServer) testEmbeddedByValue() {}
@@ -129,6 +161,24 @@ func _AuthService_Authorize_Handler(srv interface{}, ctx context.Context, dec fu
return interceptor(ctx, in, info, handler)
}
func _AuthService_AuthorizePAT_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthZPatReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthServiceServer).AuthorizePAT(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: AuthService_AuthorizePAT_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthServiceServer).AuthorizePAT(ctx, req.(*AuthZPatReq))
}
return interceptor(ctx, in, info, handler)
}
func _AuthService_Authenticate_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthNReq)
if err := dec(in); err != nil {
@@ -147,6 +197,24 @@ func _AuthService_Authenticate_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _AuthService_AuthenticatePAT_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthNReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthServiceServer).AuthenticatePAT(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: AuthService_AuthenticatePAT_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthServiceServer).AuthenticatePAT(ctx, req.(*AuthNReq))
}
return interceptor(ctx, in, info, handler)
}
// AuthService_ServiceDesc is the grpc.ServiceDesc for AuthService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -158,10 +226,18 @@ var AuthService_ServiceDesc = grpc.ServiceDesc{
MethodName: "Authorize",
Handler: _AuthService_Authorize_Handler,
},
{
MethodName: "AuthorizePAT",
Handler: _AuthService_AuthorizePAT_Handler,
},
{
MethodName: "Authenticate",
Handler: _AuthService_Authenticate_Handler,
},
{
MethodName: "AuthenticatePAT",
Handler: _AuthService_AuthenticatePAT_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "auth/v1/auth.proto",
+12
View File
@@ -10,7 +10,9 @@ option go_package = "github.com/absmach/supermq/internal/grpc/auth/v1";
// and authorization functionalities for SuperMQ services.
service AuthService {
rpc Authorize(AuthZReq) returns (AuthZRes) {}
rpc AuthorizePAT(AuthZPatReq) returns (AuthZRes) {}
rpc Authenticate(AuthNReq) returns (AuthNRes) {}
rpc AuthenticatePAT(AuthNReq) returns (AuthNRes) {}
}
@@ -36,6 +38,16 @@ message AuthZReq {
string object_type = 9; // Client, User, Group
}
message AuthZPatReq {
string user_id = 1; // User id
string pat_id = 2; // Pat id
uint32 platform_entity_type = 3; // Platform entity type
string optional_domain_id = 4; // Optional domain id
uint32 optional_domain_entity_type = 5; // Optional domain entity type
uint32 operation = 6; // Operation
repeated string entity_ids = 7; // EntityIDs
}
message AuthZRes {
bool authorized = 1;
string id = 2;
+12
View File
@@ -241,4 +241,16 @@ var (
ErrInvalidProfilePictureURL = errors.New("invalid profile picture url")
ErrMultipleEntitiesFilter = errors.New("multiple entities are provided in filter are not supported")
// ErrMissingDescription indicates missing description.
ErrMissingDescription = errors.New("missing description")
// ErrUnsupportedTokenType indicates that this type of token is not supported.
ErrUnsupportedTokenType = errors.New("unsupported content token type")
// ErrMissingUserID indicates missing user ID.
ErrMissingUserID = errors.New("missing user id")
// ErrMissingPATID indicates missing pat ID.
ErrMissingPATID = errors.New("missing pat id")
)
+22
View File
@@ -7,7 +7,29 @@ import (
"context"
)
type TokenType uint32
const (
// AccessToken represents token generated by user.
AccessToken TokenType = iota
// PersonalAccessToken represents token generated by user for automation.
PersonalAccessToken
)
func (t TokenType) String() string {
switch t {
case AccessToken:
return "access token"
case PersonalAccessToken:
return "pat"
default:
return "unknown"
}
}
type Session struct {
Type TokenType
ID string
DomainUserID string
UserID string
DomainID string
+13 -1
View File
@@ -5,6 +5,7 @@ package authsvc
import (
"context"
"strings"
"github.com/absmach/supermq/auth/api/grpc/auth"
grpcAuthV1 "github.com/absmach/supermq/internal/grpc/auth/v1"
@@ -14,6 +15,8 @@ import (
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
const patPrefix = "pat_"
type authentication struct {
authSvcClient grpcAuthV1.AuthServiceClient
}
@@ -38,9 +41,18 @@ func NewAuthentication(ctx context.Context, cfg grpcclient.Config) (authn.Authen
}
func (a authentication) Authenticate(ctx context.Context, token string) (authn.Session, error) {
if strings.HasPrefix(token, patPrefix) {
res, err := a.authSvcClient.AuthenticatePAT(ctx, &grpcAuthV1.AuthNReq{Token: token})
if err != nil {
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
}
return authn.Session{Type: authn.PersonalAccessToken, ID: res.GetId(), UserID: res.GetUserId()}, nil
}
res, err := a.authSvcClient.Authenticate(ctx, &grpcAuthV1.AuthNReq{Token: token})
if err != nil {
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
}
return authn.Session{DomainUserID: res.GetId(), UserID: res.GetUserId(), DomainID: res.GetDomainId()}, nil
return authn.Session{Type: authn.AccessToken, DomainUserID: res.GetId(), UserID: res.GetUserId(), DomainID: res.GetDomainId()}, nil
}