SMQ-3233 - Move callout to seperate middleware (#3244)

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Steve Munene
2025-12-01 19:40:49 +03:00
committed by GitHub
parent e75ce59998
commit 27b72db52e
15 changed files with 1223 additions and 620 deletions
+7 -118
View File
@@ -6,13 +6,11 @@ package middleware
import (
"context"
"fmt"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -44,12 +42,11 @@ var (
var _ channels.Service = (*authorizationMiddleware)(nil)
type authorizationMiddleware struct {
svc channels.Service
repo channels.Repository
authz smqauthz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
callout callout.Callout
svc channels.Service
repo channels.Repository
authz smqauthz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
rolemw.RoleManagerAuthorizationMiddleware
}
@@ -60,7 +57,6 @@ func NewAuthorization(
authz smqauthz.Authorization,
channelsOpPerm, rolesOpPerm map[svcutil.Operation]svcutil.Permission,
extOpPerm map[svcutil.ExternalOperation]svcutil.Permission,
callout callout.Callout,
) (channels.Service, error) {
opp := channels.NewOperationPerm()
if err := opp.AddOperationPermissionMap(channelsOpPerm); err != nil {
@@ -77,7 +73,8 @@ func NewAuthorization(
if err := extOpp.Validate(); err != nil {
return nil, err
}
ram, err := rolemw.NewAuthorization(policies.ChannelType, svc, authz, rolesOpPerm, callout)
ram, err := rolemw.NewAuthorization(policies.ChannelType, svc, authz, rolesOpPerm)
if err != nil {
return nil, err
}
@@ -89,7 +86,6 @@ func NewAuthorization(
RoleManagerAuthorizationMiddleware: ram,
opp: opp,
extOpp: extOpp,
callout: callout,
}, nil
}
@@ -130,15 +126,6 @@ func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session a
}
}
params := map[string]any{
"entities": chs,
"count": len(chs),
}
if err := am.callOut(ctx, session, channels.OpCreateChannel.String(channels.OperationNames), "", params); err != nil {
return []channels.Channel{}, []roles.RoleProvision{}, err
}
return am.svc.CreateChannels(ctx, session, chs...)
}
@@ -166,10 +153,6 @@ func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session auth
return channels.Channel{}, errors.Wrap(err, errView)
}
if err := am.callOut(ctx, session, channels.OpViewChannel.String(channels.OperationNames), id, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.ViewChannel(ctx, session, id, withRoles)
}
@@ -191,14 +174,6 @@ func (am *authorizationMiddleware) ListChannels(ctx context.Context, session aut
session.SuperAdmin = true
}
params := map[string]any{
"pagemeta": pm,
}
if err := am.callOut(ctx, session, channels.OpListChannels.String(channels.OperationNames), "", params); err != nil {
return channels.ChannelsPage{}, err
}
return am.svc.ListChannels(ctx, session, pm)
}
@@ -219,15 +194,6 @@ func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session
return channels.ChannelsPage{}, errors.Wrap(err, errList)
}
params := map[string]any{
"user_id": userID,
"pagemeta": pm,
}
if err := am.callOut(ctx, session, channels.OpListUserChannels.String(channels.OperationNames), "", params); err != nil {
return channels.ChannelsPage{}, err
}
return am.svc.ListUserChannels(ctx, session, userID, pm)
}
@@ -255,10 +221,6 @@ func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session au
return channels.Channel{}, errors.Wrap(err, errUpdate)
}
if err := am.callOut(ctx, session, channels.OpUpdateChannel.String(channels.OperationNames), channel.ID, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.UpdateChannel(ctx, session, channel)
}
@@ -286,10 +248,6 @@ func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, sessio
return channels.Channel{}, errors.Wrap(err, errUpdateTags)
}
if err := am.callOut(ctx, session, channels.OpUpdateChannelTags.String(channels.OperationNames), channel.ID, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.UpdateChannelTags(ctx, session, channel)
}
@@ -317,10 +275,6 @@ func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session au
return channels.Channel{}, errors.Wrap(err, errEnable)
}
if err := am.callOut(ctx, session, channels.OpEnableChannel.String(channels.OperationNames), id, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.EnableChannel(ctx, session, id)
}
@@ -348,10 +302,6 @@ func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session a
return channels.Channel{}, errors.Wrap(err, errDisable)
}
if err := am.callOut(ctx, session, channels.OpDisableChannel.String(channels.OperationNames), id, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.DisableChannel(ctx, session, id)
}
@@ -378,10 +328,6 @@ func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session au
return errors.Wrap(err, errDelete)
}
if err := am.callOut(ctx, session, channels.OpDeleteChannel.String(channels.OperationNames), id, nil); err != nil {
return err
}
return am.svc.RemoveChannel(ctx, session, id)
}
@@ -436,16 +382,6 @@ func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Se
}
}
params := map[string]any{
"channel_ids": chIDs,
"client_ids": thIDs,
"connection_types": connTypes,
}
if err := am.callOut(ctx, session, channels.OpConnectClient.String(channels.OperationNames), "", params); err != nil {
return err
}
return am.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
}
@@ -501,16 +437,6 @@ func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn
}
}
params := map[string]any{
"channel_ids": chIDs,
"client_ids": thIDs,
"connection_types": connTypes,
}
if err := am.callOut(ctx, session, channels.OpDisconnectClient.String(channels.OperationNames), "", params); err != nil {
return err
}
return am.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
}
@@ -548,14 +474,6 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
return errors.Wrap(err, errGroupSetChildChannels)
}
params := map[string]any{
"parent_group_id": parentGroupID,
}
if err := am.callOut(ctx, session, channels.OpSetParentGroup.String(channels.OperationNames), id, params); err != nil {
return err
}
return am.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
@@ -598,14 +516,6 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
return errors.Wrap(err, errGroupRemoveChildChannels)
}
params := map[string]any{
"parent_group_id": ch.ParentGroup,
}
if err := am.callOut(ctx, session, channels.OpRemoveParentGroup.String(channels.OperationNames), id, params); err != nil {
return err
}
return am.svc.RemoveParentGroup(ctx, session, id)
}
return nil
@@ -656,24 +566,3 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session
}
return nil
}
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.ChannelType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := am.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
+208
View File
@@ -0,0 +1,208 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"time"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
)
var _ channels.Service = (*calloutMiddleware)(nil)
type calloutMiddleware struct {
svc channels.Service
repo channels.Repository
callout callout.Callout
rolemw.RoleManagerCalloutMiddleware
}
func NewCallout(svc channels.Service, repo channels.Repository, callout callout.Callout) (channels.Service, error) {
call, err := rolemw.NewCallout(policies.ChannelType, svc, callout)
if err != nil {
return nil, err
}
return &calloutMiddleware{
svc: svc,
repo: repo,
callout: callout,
RoleManagerCalloutMiddleware: call,
}, nil
}
func (cm *calloutMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
params := map[string]any{
"entities": chs,
"count": len(chs),
}
if err := cm.callOut(ctx, session, channels.OpCreateChannel.String(channels.OperationNames), "", params); err != nil {
return []channels.Channel{}, []roles.RoleProvision{}, err
}
return cm.svc.CreateChannels(ctx, session, chs...)
}
func (cm *calloutMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
if err := cm.callOut(ctx, session, channels.OpViewChannel.String(channels.OperationNames), id, nil); err != nil {
return channels.Channel{}, err
}
return cm.svc.ViewChannel(ctx, session, id, withRoles)
}
func (cm *calloutMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
params := map[string]any{
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, channels.OpListChannels.String(channels.OperationNames), "", params); err != nil {
return channels.ChannelsPage{}, err
}
return cm.svc.ListChannels(ctx, session, pm)
}
func (cm *calloutMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
params := map[string]any{
"user_id": userID,
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, channels.OpListUserChannels.String(channels.OperationNames), "", params); err != nil {
return channels.ChannelsPage{}, err
}
return cm.svc.ListUserChannels(ctx, session, userID, pm)
}
func (cm *calloutMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if err := cm.callOut(ctx, session, channels.OpUpdateChannel.String(channels.OperationNames), channel.ID, nil); err != nil {
return channels.Channel{}, err
}
return cm.svc.UpdateChannel(ctx, session, channel)
}
func (cm *calloutMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if err := cm.callOut(ctx, session, channels.OpUpdateChannelTags.String(channels.OperationNames), channel.ID, nil); err != nil {
return channels.Channel{}, err
}
return cm.svc.UpdateChannelTags(ctx, session, channel)
}
func (cm *calloutMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if err := cm.callOut(ctx, session, channels.OpEnableChannel.String(channels.OperationNames), id, nil); err != nil {
return channels.Channel{}, err
}
return cm.svc.EnableChannel(ctx, session, id)
}
func (cm *calloutMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if err := cm.callOut(ctx, session, channels.OpDisableChannel.String(channels.OperationNames), id, nil); err != nil {
return channels.Channel{}, err
}
return cm.svc.DisableChannel(ctx, session, id)
}
func (cm *calloutMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
if err := cm.callOut(ctx, session, channels.OpDeleteChannel.String(channels.OperationNames), id, nil); err != nil {
return err
}
return cm.svc.RemoveChannel(ctx, session, id)
}
func (cm *calloutMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
params := map[string]any{
"channel_ids": chIDs,
"client_ids": thIDs,
"connection_types": connTypes,
}
if err := cm.callOut(ctx, session, channels.OpConnectClient.String(channels.OperationNames), "", params); err != nil {
return err
}
return cm.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
}
func (cm *calloutMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
params := map[string]any{
"channel_ids": chIDs,
"client_ids": thIDs,
"connection_types": connTypes,
}
if err := cm.callOut(ctx, session, channels.OpDisconnectClient.String(channels.OperationNames), "", params); err != nil {
return err
}
return cm.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
}
func (cm *calloutMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
params := map[string]any{
"parent_group_id": parentGroupID,
}
if err := cm.callOut(ctx, session, channels.OpSetParentGroup.String(channels.OperationNames), id, params); err != nil {
return err
}
return cm.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
func (cm *calloutMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
ch, err := cm.repo.RetrieveByID(ctx, id)
if err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
if ch.ParentGroup != "" {
params := map[string]any{
"parent_group_id": ch.ParentGroup,
}
if err := cm.callOut(ctx, session, channels.OpRemoveParentGroup.String(channels.OperationNames), id, params); err != nil {
return err
}
return cm.svc.RemoveParentGroup(ctx, session, id)
}
return nil
}
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.ChannelType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := cm.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
+94
View File
@@ -1266,3 +1266,97 @@ func TestDeleteUserCmd(t *testing.T) {
})
}
}
func TestSearchUsersCmd(t *testing.T) {
sdkMock := new(sdkmocks.SDK)
cli.SetSDK(sdkMock)
usersCmd := cli.NewUsersCmd()
rootCmd := setFlags(usersCmd)
usersPage := mgsdk.UsersPage{
Users: []mgsdk.User{user},
PageRes: mgsdk.PageRes{
Total: 1,
Offset: 0,
Limit: 10,
},
}
cases := []struct {
desc string
args []string
sdkErr errors.SDKError
errLogMessage string
usersPage mgsdk.UsersPage
logType outputLog
}{
{
desc: "search users by username successfully",
args: []string{
"search",
"username=testuser",
validToken,
},
usersPage: usersPage,
logType: entityLog,
},
{
desc: "search users with missing token",
args: []string{
"search",
"username=testuser",
},
logType: usageLog,
},
{
desc: "search users with missing query",
args: []string{
"search",
validToken,
},
logType: usageLog,
},
{
desc: "search users with extra arguments",
args: []string{
"search",
"username=testuser",
validToken,
extraArg,
},
logType: usageLog,
},
{
desc: "search users with service error",
args: []string{
"search",
"username=testuser",
validToken,
},
sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest).Error()),
logType: errLog,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("SearchUsers", mock.Anything, mock.Anything, mock.Anything).Return(tc.usersPage, tc.sdkErr)
out := executeCommand(t, rootCmd, tc.args...)
switch tc.logType {
case entityLog:
var page mgsdk.UsersPage
err := json.Unmarshal([]byte(out), &page)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %v", err))
assert.Equal(t, tc.usersPage, page, fmt.Sprintf("%s unexpected response: expected %v got %v", tc.desc, tc.usersPage, page))
case errLog:
assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out))
case usageLog:
assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out))
}
sdkCall.Unset()
})
}
}
+6 -102
View File
@@ -5,13 +5,11 @@ package middleware
import (
"context"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
@@ -38,12 +36,11 @@ var (
var _ clients.Service = (*authorizationMiddleware)(nil)
type authorizationMiddleware struct {
svc clients.Service
repo clients.Repository
authz smqauthz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
callout callout.Callout
svc clients.Service
repo clients.Repository
authz smqauthz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
rolemw.RoleManagerAuthorizationMiddleware
}
@@ -55,7 +52,6 @@ func NewAuthorization(
repo clients.Repository,
clientsOpPerm, rolesOpPerm map[svcutil.Operation]svcutil.Permission,
extOpPerm map[svcutil.ExternalOperation]svcutil.Permission,
callout callout.Callout,
) (clients.Service, error) {
opp := clients.NewOperationPerm()
if err := opp.AddOperationPermissionMap(clientsOpPerm); err != nil {
@@ -64,7 +60,7 @@ func NewAuthorization(
if err := opp.Validate(); err != nil {
return nil, err
}
ram, err := rolemw.NewAuthorization(policies.ClientType, svc, authz, rolesOpPerm, callout)
ram, err := rolemw.NewAuthorization(policies.ClientType, svc, authz, rolesOpPerm)
if err != nil {
return nil, err
}
@@ -83,7 +79,6 @@ func NewAuthorization(
opp: opp,
extOpp: extOpp,
RoleManagerAuthorizationMiddleware: ram,
callout: callout,
}, nil
}
@@ -110,15 +105,6 @@ func (am *authorizationMiddleware) CreateClients(ctx context.Context, session au
return []clients.Client{}, []roles.RoleProvision{}, errors.Wrap(err, errDomainCreateClients)
}
params := map[string]any{
"entities": client,
"count": len(client),
}
if err := am.callOut(ctx, session, clients.OpCreateClient.String(clients.OperationNames), "", params); err != nil {
return []clients.Client{}, []roles.RoleProvision{}, err
}
return am.svc.CreateClients(ctx, session, client...)
}
@@ -146,10 +132,6 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi
return clients.Client{}, errors.Wrap(err, errView)
}
if err := am.callOut(ctx, session, clients.OpViewClient.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
return am.svc.View(ctx, session, id, withRoles)
}
@@ -171,14 +153,6 @@ func (am *authorizationMiddleware) ListClients(ctx context.Context, session auth
session.SuperAdmin = true
}
params := map[string]any{
"pagemeta": pm,
}
if err := am.callOut(ctx, session, clients.OpListClients.String(clients.OperationNames), "", params); err != nil {
return clients.ClientsPage{}, err
}
return am.svc.ListClients(ctx, session, pm)
}
@@ -200,15 +174,6 @@ func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session
return clients.ClientsPage{}, err
}
params := map[string]any{
"user_id": userID,
"pagemeta": pm,
}
if err := am.callOut(ctx, session, clients.OpListUserClients.String(clients.OperationNames), "", params); err != nil {
return clients.ClientsPage{}, err
}
return am.svc.ListUserClients(ctx, session, userID, pm)
}
@@ -236,10 +201,6 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses
return clients.Client{}, errors.Wrap(err, errUpdate)
}
if err := am.callOut(ctx, session, clients.OpUpdateClient.String(clients.OperationNames), client.ID, nil); err != nil {
return clients.Client{}, err
}
return am.svc.Update(ctx, session, client)
}
@@ -267,10 +228,6 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn
return clients.Client{}, errors.Wrap(err, errUpdateTags)
}
if err := am.callOut(ctx, session, clients.OpUpdateClientTags.String(clients.OperationNames), client.ID, nil); err != nil {
return clients.Client{}, err
}
return am.svc.UpdateTags(ctx, session, client)
}
@@ -298,10 +255,6 @@ func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session aut
return clients.Client{}, errors.Wrap(err, errUpdateSecret)
}
if err := am.callOut(ctx, session, clients.OpUpdateClientSecret.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
return am.svc.UpdateSecret(ctx, session, id, key)
}
@@ -329,10 +282,6 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses
return clients.Client{}, errors.Wrap(err, errEnable)
}
if err := am.callOut(ctx, session, clients.OpEnableClient.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
return am.svc.Enable(ctx, session, id)
}
@@ -360,10 +309,6 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se
return clients.Client{}, errors.Wrap(err, errDisable)
}
if err := am.callOut(ctx, session, clients.OpDisableClient.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
return am.svc.Disable(ctx, session, id)
}
@@ -390,10 +335,6 @@ func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Ses
return errors.Wrap(err, errDelete)
}
if err := am.callOut(ctx, session, clients.OpDeleteClient.String(clients.OperationNames), id, nil); err != nil {
return err
}
return am.svc.Delete(ctx, session, id)
}
@@ -431,14 +372,6 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
return errors.Wrap(err, errGroupSetChildClients)
}
params := map[string]any{
"parent_id": parentGroupID,
}
if err := am.callOut(ctx, session, clients.OpSetParentGroup.String(clients.OperationNames), id, params); err != nil {
return err
}
return am.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
@@ -482,14 +415,6 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
return errors.Wrap(err, errGroupRemoveChildClients)
}
params := map[string]any{
"parent_id": th.ParentGroup,
}
if err := am.callOut(ctx, session, clients.OpRemoveParentGroup.String(clients.OperationNames), id, params); err != nil {
return err
}
return am.svc.RemoveParentGroup(ctx, session, id)
}
return nil
@@ -540,24 +465,3 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session
}
return nil
}
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.ClientType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := am.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
+185
View File
@@ -0,0 +1,185 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"time"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
)
var _ clients.Service = (*calloutMiddleware)(nil)
type calloutMiddleware struct {
svc clients.Service
repo clients.Repository
callout callout.Callout
rolemw.RoleManagerCalloutMiddleware
}
func NewCallout(svc clients.Service, repo clients.Repository, callout callout.Callout) (clients.Service, error) {
call, err := rolemw.NewCallout(policies.ClientType, svc, callout)
if err != nil {
return nil, err
}
return &calloutMiddleware{
svc: svc,
repo: repo,
callout: callout,
RoleManagerCalloutMiddleware: call,
}, nil
}
func (cm *calloutMiddleware) CreateClients(ctx context.Context, session authn.Session, client ...clients.Client) ([]clients.Client, []roles.RoleProvision, error) {
params := map[string]any{
"entities": client,
"count": len(client),
}
if err := cm.callOut(ctx, session, clients.OpCreateClient.String(clients.OperationNames), "", params); err != nil {
return []clients.Client{}, []roles.RoleProvision{}, err
}
return cm.svc.CreateClients(ctx, session, client...)
}
func (cm *calloutMiddleware) View(ctx context.Context, session authn.Session, id string, withRoles bool) (clients.Client, error) {
if err := cm.callOut(ctx, session, clients.OpViewClient.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
return cm.svc.View(ctx, session, id, withRoles)
}
func (cm *calloutMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
params := map[string]any{
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, clients.OpListClients.String(clients.OperationNames), "", params); err != nil {
return clients.ClientsPage{}, err
}
return cm.svc.ListClients(ctx, session, pm)
}
func (cm *calloutMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
params := map[string]any{
"user_id": userID,
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, clients.OpListUserClients.String(clients.OperationNames), "", params); err != nil {
return clients.ClientsPage{}, err
}
return cm.svc.ListUserClients(ctx, session, userID, pm)
}
func (cm *calloutMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
if err := cm.callOut(ctx, session, clients.OpUpdateClient.String(clients.OperationNames), client.ID, nil); err != nil {
return clients.Client{}, err
}
return cm.svc.Update(ctx, session, client)
}
func (cm *calloutMiddleware) UpdateTags(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
if err := cm.callOut(ctx, session, clients.OpUpdateClientTags.String(clients.OperationNames), client.ID, nil); err != nil {
return clients.Client{}, err
}
return cm.svc.UpdateTags(ctx, session, client)
}
func (cm *calloutMiddleware) UpdateSecret(ctx context.Context, session authn.Session, id, key string) (clients.Client, error) {
if err := cm.callOut(ctx, session, clients.OpUpdateClientSecret.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
return cm.svc.UpdateSecret(ctx, session, id, key)
}
func (cm *calloutMiddleware) Enable(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if err := cm.callOut(ctx, session, clients.OpEnableClient.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
return cm.svc.Enable(ctx, session, id)
}
func (cm *calloutMiddleware) Disable(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if err := cm.callOut(ctx, session, clients.OpDisableClient.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
return cm.svc.Disable(ctx, session, id)
}
func (cm *calloutMiddleware) Delete(ctx context.Context, session authn.Session, id string) error {
if err := cm.callOut(ctx, session, clients.OpDeleteClient.String(clients.OperationNames), id, nil); err != nil {
return err
}
return cm.svc.Delete(ctx, session, id)
}
func (cm *calloutMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
params := map[string]any{
"parent_id": parentGroupID,
}
if err := cm.callOut(ctx, session, clients.OpSetParentGroup.String(clients.OperationNames), id, params); err != nil {
return err
}
return cm.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
func (cm *calloutMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
th, err := cm.repo.RetrieveByID(ctx, id)
if err != nil {
return err
}
if th.ParentGroup != "" {
params := map[string]any{
"parent_id": th.ParentGroup,
}
if err := cm.callOut(ctx, session, clients.OpRemoveParentGroup.String(clients.OperationNames), id, params); err != nil {
return err
}
}
return cm.svc.RemoveParentGroup(ctx, session, id)
}
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.ClientType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := cm.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
+7 -1
View File
@@ -363,10 +363,16 @@ func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, cach
counter, latency := prometheus.MakeMetrics("channels", "api")
svc = middleware.NewMetrics(svc, counter, latency)
svc, err = middleware.NewAuthorization(svc, repo, authz, channels.NewOperationPermissionMap(), channels.NewRolesOperationPermissionMap(), channels.NewExternalOperationPermissionMap(), callout)
svc, err = middleware.NewAuthorization(svc, repo, authz, channels.NewOperationPermissionMap(), channels.NewRolesOperationPermissionMap(), channels.NewExternalOperationPermissionMap())
if err != nil {
return nil, nil, err
}
svc, err = middleware.NewCallout(svc, repo, callout)
if err != nil {
return nil, nil, err
}
svc = middleware.NewLogging(svc, logger)
psvc := pChannels.New(repo, cache, pe, ps, da)
+7 -1
View File
@@ -364,10 +364,16 @@ func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, auth
counter, latency := prometheus.MakeMetrics(svcName, "api")
csvc = middleware.NewMetrics(csvc, counter, latency)
csvc, err = middleware.NewAuthorization(policies.ClientType, csvc, authz, repo, clients.NewOperationPermissionMap(), clients.NewRolesOperationPermissionMap(), clients.NewExternalOperationPermissionMap(), callout)
csvc, err = middleware.NewAuthorization(policies.ClientType, csvc, authz, repo, clients.NewOperationPermissionMap(), clients.NewRolesOperationPermissionMap(), clients.NewExternalOperationPermissionMap())
if err != nil {
return nil, nil, err
}
csvc, err = middleware.NewCallout(csvc, repo, callout)
if err != nil {
return nil, nil, err
}
csvc = middleware.NewLogging(csvc, logger)
isvc := pClients.New(repo, cache, pe, ps)
+6 -1
View File
@@ -281,7 +281,12 @@ func newDomainService(ctx context.Context, domainsRepo domainsSvc.Repository, ca
return nil, fmt.Errorf("failed to init domain event store middleware: %w", err)
}
svc, err = dmw.NewAuthorization(policies.DomainType, svc, authz, domains.NewOperationPermissionMap(), domains.NewRolesOperationPermissionMap(), callout)
svc, err = dmw.NewAuthorization(policies.DomainType, svc, authz, domains.NewOperationPermissionMap(), domains.NewRolesOperationPermissionMap())
if err != nil {
return nil, err
}
svc, err = dmw.NewCallout(svc, callout)
if err != nil {
return nil, err
}
+6 -1
View File
@@ -329,7 +329,12 @@ func newService(ctx context.Context, authz smqauthz.Authorization, policy polici
}
svc, err = middleware.NewAuthorization(policies.GroupType, svc, repo, authz, groups.NewOperationPermissionMap(), groups.NewRolesOperationPermissionMap(),
groups.NewExternalOperationPermissionMap(), callout)
groups.NewExternalOperationPermissionMap())
if err != nil {
return nil, nil, err
}
svc, err = middleware.NewCallout(svc, repo, callout)
if err != nil {
return nil, nil, err
}
+6 -117
View File
@@ -5,14 +5,12 @@ package middleware
import (
"context"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authz"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
@@ -27,15 +25,14 @@ var _ domains.Service = (*authorizationMiddleware)(nil)
var ErrMemberExist = errors.New("user is already a member of the domain")
type authorizationMiddleware struct {
svc domains.Service
authz smqauthz.Authorization
opp svcutil.OperationPerm
callout callout.Callout
svc domains.Service
authz smqauthz.Authorization
opp svcutil.OperationPerm
rolemw.RoleManagerAuthorizationMiddleware
}
// NewAuthorization adds authorization to the domains service.
func NewAuthorization(entityType string, svc domains.Service, authz smqauthz.Authorization, domainsOpPerm, rolesOpPerm map[svcutil.Operation]svcutil.Permission, callout callout.Callout) (domains.Service, error) {
func NewAuthorization(entityType string, svc domains.Service, authz smqauthz.Authorization, domainsOpPerm, rolesOpPerm map[svcutil.Operation]svcutil.Permission) (domains.Service, error) {
opp := domains.NewOperationPerm()
if err := opp.AddOperationPermissionMap(domainsOpPerm); err != nil {
return nil, err
@@ -44,7 +41,7 @@ func NewAuthorization(entityType string, svc domains.Service, authz smqauthz.Aut
return nil, err
}
ram, err := rolemw.NewAuthorization(entityType, svc, authz, rolesOpPerm, callout)
ram, err := rolemw.NewAuthorization(entityType, svc, authz, rolesOpPerm)
if err != nil {
return nil, err
}
@@ -52,16 +49,11 @@ func NewAuthorization(entityType string, svc domains.Service, authz smqauthz.Aut
svc: svc,
authz: authz,
opp: opp,
callout: callout,
RoleManagerAuthorizationMiddleware: ram,
}, nil
}
func (am *authorizationMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (domains.Domain, []roles.RoleProvision, error) {
if err := am.callOut(ctx, session, domains.OpCreateDomain.String(domains.OperationNames), d.ID, nil); err != nil {
return domains.Domain{}, nil, err
}
return am.svc.CreateDomain(ctx, session, d)
}
@@ -81,14 +73,6 @@ func (am *authorizationMiddleware) RetrieveDomain(ctx context.Context, session a
return domains.Domain{}, err
}
params := map[string]any{
"with_roles": withRoles,
}
if err := am.callOut(ctx, session, domains.OpRetrieveDomain.String(domains.OperationNames), id, params); err != nil {
return domains.Domain{}, err
}
return am.svc.RetrieveDomain(ctx, session, id, withRoles)
}
@@ -103,14 +87,6 @@ func (am *authorizationMiddleware) UpdateDomain(ctx context.Context, session aut
return domains.Domain{}, err
}
params := map[string]any{
"domain_req": d,
}
if err := am.callOut(ctx, session, domains.OpUpdateDomain.String(domains.OperationNames), id, params); err != nil {
return domains.Domain{}, err
}
return am.svc.UpdateDomain(ctx, session, id, d)
}
@@ -125,10 +101,6 @@ func (am *authorizationMiddleware) EnableDomain(ctx context.Context, session aut
return domains.Domain{}, err
}
if err := am.callOut(ctx, session, domains.OpEnableDomain.String(domains.OperationNames), id, nil); err != nil {
return domains.Domain{}, err
}
return am.svc.EnableDomain(ctx, session, id)
}
@@ -143,10 +115,6 @@ func (am *authorizationMiddleware) DisableDomain(ctx context.Context, session au
return domains.Domain{}, err
}
if err := am.callOut(ctx, session, domains.OpDisableDomain.String(domains.OperationNames), id, nil); err != nil {
return domains.Domain{}, err
}
return am.svc.DisableDomain(ctx, session, id)
}
@@ -163,10 +131,6 @@ func (am *authorizationMiddleware) FreezeDomain(ctx context.Context, session aut
return domains.Domain{}, err
}
if err := am.callOut(ctx, session, domains.OpFreezeDomain.String(domains.OperationNames), id, nil); err != nil {
return domains.Domain{}, err
}
return am.svc.FreezeDomain(ctx, session, id)
}
@@ -175,14 +139,6 @@ func (am *authorizationMiddleware) ListDomains(ctx context.Context, session auth
session.SuperAdmin = true
}
params := map[string]any{
"page": page,
}
if err := am.callOut(ctx, session, domains.OpListDomains.String(domains.OperationNames), "", params); err != nil {
return domains.DomainsPage{}, err
}
return am.svc.ListDomains(ctx, session, page)
}
@@ -197,28 +153,10 @@ func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session a
return domains.Invitation{}, err
}
params := map[string]any{
"invitation": invitation,
}
// While entity here is technically an invitation, Domain is used as
// the entity in callout since the invitation refers to the domain.
if err := am.callOut(ctx, session, domains.OpSendInvitation.String(domains.OperationNames), invitation.DomainID, params); err != nil {
return domains.Invitation{}, err
}
return am.svc.SendInvitation(ctx, session, invitation)
}
func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session authn.Session, page domains.InvitationPageMeta) (invs domains.InvitationPage, err error) {
params := map[string]any{
"page": page,
}
if err := am.callOut(ctx, session, domains.OpListInvitations.String(domains.OperationNames), "", params); err != nil {
return domains.InvitationPage{}, err
}
return am.svc.ListInvitations(ctx, session, page)
}
@@ -227,34 +165,14 @@ func (am *authorizationMiddleware) ListDomainInvitations(ctx context.Context, se
return domains.InvitationPage{}, err
}
params := map[string]any{
"page": page,
}
if err := am.callOut(ctx, session, domains.OpListDomainInvitations.String(domains.OperationNames), page.DomainID, params); err != nil {
return domains.InvitationPage{}, err
}
return am.svc.ListDomainInvitations(ctx, session, page)
}
func (am *authorizationMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (inv domains.Invitation, err error) {
// Similar to sending an invitation, Domain is used as the
// entity in callout since the invitation refers to the domain.
if err := am.callOut(ctx, session, domains.OpAcceptInvitation.String(domains.OperationNames), domainID, nil); err != nil {
return domains.Invitation{}, err
}
return am.svc.AcceptInvitation(ctx, session, domainID)
}
func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (inv domains.Invitation, err error) {
// Similar to sending and accepting, Domain is used as
// the entity in callout since the invitation refers to the domain.
if err := am.callOut(ctx, session, domains.OpRejectInvitation.String(domains.OperationNames), domainID, nil); err != nil {
return domains.Invitation{}, err
}
func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (domains.Invitation, error) {
return am.svc.RejectInvitation(ctx, session, domainID)
}
@@ -264,14 +182,6 @@ func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session
return err
}
params := map[string]any{
"invitee_user_id": inviteeUserID,
}
if err := am.callOut(ctx, session, domains.OpDeleteInvitation.String(domains.OperationNames), domainID, params); err != nil {
return err
}
return am.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID)
}
@@ -350,24 +260,3 @@ func (am *authorizationMiddleware) extAuthorize(ctx context.Context, subj, perm,
return nil
}
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.DomainType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: entityID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := am.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
+196
View File
@@ -0,0 +1,196 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"time"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
)
var _ domains.Service = (*calloutMiddleware)(nil)
type calloutMiddleware struct {
svc domains.Service
callout callout.Callout
rolemw.RoleManagerCalloutMiddleware
}
func NewCallout(svc domains.Service, callout callout.Callout) (domains.Service, error) {
call, err := rolemw.NewCallout(policies.ClientType, svc, callout)
if err != nil {
return nil, err
}
return &calloutMiddleware{
svc: svc,
callout: callout,
RoleManagerCalloutMiddleware: call,
}, nil
}
func (cm *calloutMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (domains.Domain, []roles.RoleProvision, error) {
if err := cm.callOut(ctx, session, domains.OpCreateDomain.String(domains.OperationNames), d.ID, nil); err != nil {
return domains.Domain{}, nil, err
}
return cm.svc.CreateDomain(ctx, session, d)
}
func (cm *calloutMiddleware) RetrieveDomain(ctx context.Context, session authn.Session, id string, withRoles bool) (domains.Domain, error) {
params := map[string]any{
"with_roles": withRoles,
}
if err := cm.callOut(ctx, session, domains.OpRetrieveDomain.String(domains.OperationNames), id, params); err != nil {
return domains.Domain{}, err
}
return cm.svc.RetrieveDomain(ctx, session, id, withRoles)
}
func (cm *calloutMiddleware) UpdateDomain(ctx context.Context, session authn.Session, id string, d domains.DomainReq) (domains.Domain, error) {
params := map[string]any{
"domain_req": d,
}
if err := cm.callOut(ctx, session, domains.OpUpdateDomain.String(domains.OperationNames), id, params); err != nil {
return domains.Domain{}, err
}
return cm.svc.UpdateDomain(ctx, session, id, d)
}
func (cm *calloutMiddleware) EnableDomain(ctx context.Context, session authn.Session, id string) (domains.Domain, error) {
if err := cm.callOut(ctx, session, domains.OpEnableDomain.String(domains.OperationNames), id, nil); err != nil {
return domains.Domain{}, err
}
return cm.svc.EnableDomain(ctx, session, id)
}
func (cm *calloutMiddleware) DisableDomain(ctx context.Context, session authn.Session, id string) (domains.Domain, error) {
if err := cm.callOut(ctx, session, domains.OpDisableDomain.String(domains.OperationNames), id, nil); err != nil {
return domains.Domain{}, err
}
return cm.svc.DisableDomain(ctx, session, id)
}
func (cm *calloutMiddleware) FreezeDomain(ctx context.Context, session authn.Session, id string) (domains.Domain, error) {
if err := cm.callOut(ctx, session, domains.OpFreezeDomain.String(domains.OperationNames), id, nil); err != nil {
return domains.Domain{}, err
}
return cm.svc.FreezeDomain(ctx, session, id)
}
func (cm *calloutMiddleware) ListDomains(ctx context.Context, session authn.Session, page domains.Page) (domains.DomainsPage, error) {
params := map[string]any{
"page": page,
}
if err := cm.callOut(ctx, session, domains.OpListDomains.String(domains.OperationNames), "", params); err != nil {
return domains.DomainsPage{}, err
}
return cm.svc.ListDomains(ctx, session, page)
}
func (cm *calloutMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (domains.Invitation, error) {
params := map[string]any{
"invitation": invitation,
}
// While entity here is technically an invitation, Domain is used as
// the entity in callout since the invitation refers to the domain.
if err := cm.callOut(ctx, session, domains.OpSendInvitation.String(domains.OperationNames), invitation.DomainID, params); err != nil {
return domains.Invitation{}, err
}
return cm.svc.SendInvitation(ctx, session, invitation)
}
func (cm *calloutMiddleware) ListInvitations(ctx context.Context, session authn.Session, page domains.InvitationPageMeta) (domains.InvitationPage, error) {
params := map[string]any{
"page": page,
}
if err := cm.callOut(ctx, session, domains.OpListInvitations.String(domains.OperationNames), "", params); err != nil {
return domains.InvitationPage{}, err
}
return cm.svc.ListInvitations(ctx, session, page)
}
func (cm *calloutMiddleware) ListDomainInvitations(ctx context.Context, session authn.Session, page domains.InvitationPageMeta) (domains.InvitationPage, error) {
params := map[string]any{
"page": page,
}
if err := cm.callOut(ctx, session, domains.OpListDomainInvitations.String(domains.OperationNames), page.DomainID, params); err != nil {
return domains.InvitationPage{}, err
}
return cm.svc.ListDomainInvitations(ctx, session, page)
}
func (cm *calloutMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (domains.Invitation, error) {
// Similar to sending an invitation, Domain is used as the
// entity in callout since the invitation refers to the domain.
if err := cm.callOut(ctx, session, domains.OpAcceptInvitation.String(domains.OperationNames), domainID, nil); err != nil {
return domains.Invitation{}, err
}
return cm.svc.AcceptInvitation(ctx, session, domainID)
}
func (cm *calloutMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (domains.Invitation, error) {
// Similar to sending and accepting, Domain is used as
// the entity in callout since the invitation refers to the domain.
if err := cm.callOut(ctx, session, domains.OpRejectInvitation.String(domains.OperationNames), domainID, nil); err != nil {
return domains.Invitation{}, err
}
return cm.svc.RejectInvitation(ctx, session, domainID)
}
func (cm *calloutMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) error {
params := map[string]any{
"invitee_user_id": inviteeUserID,
}
if err := cm.callOut(ctx, session, domains.OpDeleteInvitation.String(domains.OperationNames), domainID, params); err != nil {
return err
}
return cm.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID)
}
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.DomainType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: entityID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := cm.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
+6 -135
View File
@@ -6,13 +6,11 @@ package middleware
import (
"context"
"fmt"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
@@ -44,12 +42,11 @@ var (
var _ groups.Service = (*authorizationMiddleware)(nil)
type authorizationMiddleware struct {
svc groups.Service
repo groups.Repository
authz smqauthz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
callout callout.Callout
svc groups.Service
repo groups.Repository
authz smqauthz.Authorization
opp svcutil.OperationPerm
extOpp svcutil.ExternalOperationPerm
rolemw.RoleManagerAuthorizationMiddleware
}
@@ -60,7 +57,6 @@ func NewAuthorization(entityType string,
authz smqauthz.Authorization,
groupsOpPerm, rolesOpPerm map[svcutil.Operation]svcutil.Permission,
extOpPerm map[svcutil.ExternalOperation]svcutil.Permission,
callout callout.Callout,
) (groups.Service, error) {
opp := groups.NewOperationPerm()
if err := opp.AddOperationPermissionMap(groupsOpPerm); err != nil {
@@ -78,7 +74,7 @@ func NewAuthorization(entityType string,
return nil, err
}
ram, err := rolemw.NewAuthorization(entityType, svc, authz, rolesOpPerm, callout)
ram, err := rolemw.NewAuthorization(entityType, svc, authz, rolesOpPerm)
if err != nil {
return nil, err
}
@@ -90,7 +86,6 @@ func NewAuthorization(entityType string,
opp: opp,
extOpp: extOpp,
RoleManagerAuthorizationMiddleware: ram,
callout: callout,
}, nil
}
@@ -132,15 +127,6 @@ func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session auth
}
}
params := map[string]any{
"entities": []groups.Group{g},
"count": 1,
}
if err := am.callOut(ctx, session, groups.OpCreateGroup.String(groups.OperationNames), "", params); err != nil {
return groups.Group{}, []roles.RoleProvision{}, err
}
return am.svc.CreateGroup(ctx, session, g)
}
@@ -169,10 +155,6 @@ func (am *authorizationMiddleware) UpdateGroup(ctx context.Context, session auth
return groups.Group{}, errors.Wrap(errUpdate, err)
}
if err := am.callOut(ctx, session, groups.OpUpdateGroup.String(groups.OperationNames), g.ID, nil); err != nil {
return groups.Group{}, err
}
return am.svc.UpdateGroup(ctx, session, g)
}
@@ -200,10 +182,6 @@ func (am *authorizationMiddleware) UpdateGroupTags(ctx context.Context, session
return groups.Group{}, errors.Wrap(errUpdateTags, err)
}
if err := am.callOut(ctx, session, groups.OpUpdateGroupTags.String(groups.OperationNames), group.ID, nil); err != nil {
return groups.Group{}, err
}
return am.svc.UpdateGroupTags(ctx, session, group)
}
@@ -232,10 +210,6 @@ func (am *authorizationMiddleware) ViewGroup(ctx context.Context, session authn.
return groups.Group{}, errors.Wrap(errView, err)
}
if err := am.callOut(ctx, session, groups.OpViewGroup.String(groups.OperationNames), id, nil); err != nil {
return groups.Group{}, err
}
return am.svc.ViewGroup(ctx, session, id, withRoles)
}
@@ -268,14 +242,6 @@ func (am *authorizationMiddleware) ListGroups(ctx context.Context, session authn
return groups.Page{}, errors.Wrap(errDomainListGroups, err)
}
params := map[string]any{
"pagemeta": gm,
}
if err := am.callOut(ctx, session, groups.OpListGroups.String(groups.OperationNames), "", params); err != nil {
return groups.Page{}, err
}
return am.svc.ListGroups(ctx, session, gm)
}
@@ -295,15 +261,6 @@ func (am *authorizationMiddleware) ListUserGroups(ctx context.Context, session a
return groups.Page{}, errors.Wrap(errDomainListGroups, err)
}
params := map[string]any{
"user_id": userID,
"pagemeta": pm,
}
if err := am.callOut(ctx, session, groups.OpListUserGroups.String(groups.OperationNames), "", params); err != nil {
return groups.Page{}, err
}
return am.svc.ListUserGroups(ctx, session, userID, pm)
}
@@ -331,10 +288,6 @@ func (am *authorizationMiddleware) EnableGroup(ctx context.Context, session auth
return groups.Group{}, errors.Wrap(errEnable, err)
}
if err := am.callOut(ctx, session, groups.OpEnableGroup.String(groups.OperationNames), id, nil); err != nil {
return groups.Group{}, err
}
return am.svc.EnableGroup(ctx, session, id)
}
@@ -362,10 +315,6 @@ func (am *authorizationMiddleware) DisableGroup(ctx context.Context, session aut
return groups.Group{}, errors.Wrap(errDisable, err)
}
if err := am.callOut(ctx, session, groups.OpDisableGroup.String(groups.OperationNames), id, nil); err != nil {
return groups.Group{}, err
}
return am.svc.DisableGroup(ctx, session, id)
}
@@ -392,10 +341,6 @@ func (am *authorizationMiddleware) DeleteGroup(ctx context.Context, session auth
return errors.Wrap(errDelete, err)
}
if err := am.callOut(ctx, session, groups.OpDeleteGroup.String(groups.OperationNames), id, nil); err != nil {
return err
}
return am.svc.DeleteGroup(ctx, session, id)
}
@@ -423,14 +368,6 @@ func (am *authorizationMiddleware) RetrieveGroupHierarchy(ctx context.Context, s
return groups.HierarchyPage{}, errors.Wrap(errViewHierarchy, err)
}
params := map[string]any{
"hierarchy_pagemeta": hm,
}
if err := am.callOut(ctx, session, groups.OpRetrieveGroupHierarchy.String(groups.OperationNames), id, params); err != nil {
return groups.HierarchyPage{}, err
}
return am.svc.RetrieveGroupHierarchy(ctx, session, id, hm)
}
@@ -468,14 +405,6 @@ func (am *authorizationMiddleware) AddParentGroup(ctx context.Context, session a
return errors.Wrap(errParentGroupSetChildGroup, err)
}
params := map[string]any{
"parent_id": parentID,
}
if err := am.callOut(ctx, session, groups.OpAddParentGroup.String(groups.OperationNames), id, params); err != nil {
return err
}
return am.svc.AddParentGroup(ctx, session, id, parentID)
}
@@ -520,13 +449,6 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
}
}
params := map[string]any{
"parent_id": group.Parent,
}
if err := am.callOut(ctx, session, groups.OpRemoveParentGroup.String(groups.OperationNames), id, params); err != nil {
return err
}
return am.svc.RemoveParentGroup(ctx, session, id)
}
@@ -566,14 +488,6 @@ func (am *authorizationMiddleware) AddChildrenGroups(ctx context.Context, sessio
}
}
params := map[string]any{
"children_group_ids": childrenGroupIDs,
}
if err := am.callOut(ctx, session, groups.OpAddChildrenGroups.String(groups.OperationNames), id, params); err != nil {
return err
}
return am.svc.AddChildrenGroups(ctx, session, id, childrenGroupIDs)
}
@@ -601,14 +515,6 @@ func (am *authorizationMiddleware) RemoveChildrenGroups(ctx context.Context, ses
return errors.Wrap(errRemoveChildrenGroups, err)
}
params := map[string]any{
"children_group_ids": childrenGroupIDs,
}
if err := am.callOut(ctx, session, groups.OpRemoveChildrenGroups.String(groups.OperationNames), id, params); err != nil {
return err
}
return am.svc.RemoveChildrenGroups(ctx, session, id, childrenGroupIDs)
}
@@ -636,10 +542,6 @@ func (am *authorizationMiddleware) RemoveAllChildrenGroups(ctx context.Context,
return err
}
if err := am.callOut(ctx, session, groups.OpRemoveAllChildrenGroups.String(groups.OperationNames), id, nil); err != nil {
return err
}
return am.svc.RemoveAllChildrenGroups(ctx, session, id)
}
@@ -667,16 +569,6 @@ func (am *authorizationMiddleware) ListChildrenGroups(ctx context.Context, sessi
return groups.Page{}, errors.Wrap(errListChildrenGroups, err)
}
params := map[string]any{
"start_level": startLevel,
"end_level": endLevel,
"pagemeta": pm,
}
if err := am.callOut(ctx, session, groups.OpListChildrenGroups.String(groups.OperationNames), id, params); err != nil {
return groups.Page{}, err
}
return am.svc.ListChildrenGroups(ctx, session, id, startLevel, endLevel, pm)
}
@@ -723,24 +615,3 @@ func (am *authorizationMiddleware) extAuthorize(ctx context.Context, extOp svcut
return nil
}
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.GroupType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := am.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
+243
View File
@@ -0,0 +1,243 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"time"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
)
var _ groups.Service = (*calloutMiddleware)(nil)
type calloutMiddleware struct {
svc groups.Service
repo groups.Repository
callout callout.Callout
rolemw.RoleManagerCalloutMiddleware
}
func NewCallout(svc groups.Service, repo groups.Repository, callout callout.Callout) (groups.Service, error) {
call, err := rolemw.NewCallout(policies.ClientType, svc, callout)
if err != nil {
return nil, err
}
return &calloutMiddleware{
svc: svc,
repo: repo,
callout: callout,
RoleManagerCalloutMiddleware: call,
}, nil
}
func (cm *calloutMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, []roles.RoleProvision, error) {
params := map[string]any{
"entities": []groups.Group{g},
"count": 1,
}
if err := cm.callOut(ctx, session, groups.OpCreateGroup.String(groups.OperationNames), "", params); err != nil {
return groups.Group{}, nil, err
}
return cm.svc.CreateGroup(ctx, session, g)
}
func (cm *calloutMiddleware) UpdateGroup(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error) {
params := map[string]any{
"group": group,
}
if err := cm.callOut(ctx, session, groups.OpUpdateGroup.String(groups.OperationNames), group.ID, params); err != nil {
return groups.Group{}, err
}
return cm.svc.UpdateGroup(ctx, session, group)
}
func (cm *calloutMiddleware) UpdateGroupTags(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error) {
params := map[string]any{
"tags": group.Tags,
}
if err := cm.callOut(ctx, session, groups.OpUpdateGroupTags.String(groups.OperationNames), group.ID, params); err != nil {
return groups.Group{}, err
}
return cm.svc.UpdateGroupTags(ctx, session, group)
}
func (cm *calloutMiddleware) ViewGroup(ctx context.Context, session authn.Session, id string, withRoles bool) (groups.Group, error) {
if err := cm.callOut(ctx, session, groups.OpViewGroup.String(groups.OperationNames), id, nil); err != nil {
return groups.Group{}, err
}
return cm.svc.ViewGroup(ctx, session, id, withRoles)
}
func (cm *calloutMiddleware) ListGroups(ctx context.Context, session authn.Session, gm groups.PageMeta) (groups.Page, error) {
params := map[string]any{
"pagemeta": gm,
}
if err := cm.callOut(ctx, session, groups.OpListGroups.String(groups.OperationNames), "", params); err != nil {
return groups.Page{}, err
}
return cm.svc.ListGroups(ctx, session, gm)
}
func (cm *calloutMiddleware) ListUserGroups(ctx context.Context, session authn.Session, userID string, gm groups.PageMeta) (groups.Page, error) {
params := map[string]any{
"user_id": userID,
"pagemeta": gm,
}
if err := cm.callOut(ctx, session, groups.OpListUserGroups.String(groups.OperationNames), "", params); err != nil {
return groups.Page{}, err
}
return cm.svc.ListUserGroups(ctx, session, userID, gm)
}
func (cm *calloutMiddleware) EnableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
if err := cm.callOut(ctx, session, groups.OpEnableGroup.String(groups.OperationNames), id, nil); err != nil {
return groups.Group{}, err
}
return cm.svc.EnableGroup(ctx, session, id)
}
func (cm *calloutMiddleware) DisableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
if err := cm.callOut(ctx, session, groups.OpDisableGroup.String(groups.OperationNames), id, nil); err != nil {
return groups.Group{}, err
}
return cm.svc.DisableGroup(ctx, session, id)
}
func (cm *calloutMiddleware) DeleteGroup(ctx context.Context, session authn.Session, id string) error {
if err := cm.callOut(ctx, session, groups.OpDeleteGroup.String(groups.OperationNames), id, nil); err != nil {
return err
}
return cm.svc.DeleteGroup(ctx, session, id)
}
func (cm *calloutMiddleware) RetrieveGroupHierarchy(ctx context.Context, session authn.Session, id string, hm groups.HierarchyPageMeta) (groups.HierarchyPage, error) {
params := map[string]any{
"hierarchy_pagemeta": hm,
}
if err := cm.callOut(ctx, session, groups.OpRetrieveGroupHierarchy.String(groups.OperationNames), id, params); err != nil {
return groups.HierarchyPage{}, err
}
return cm.svc.RetrieveGroupHierarchy(ctx, session, id, hm)
}
func (cm *calloutMiddleware) AddParentGroup(ctx context.Context, session authn.Session, id, parentID string) error {
params := map[string]any{
"parent_id": parentID,
}
if err := cm.callOut(ctx, session, groups.OpAddParentGroup.String(groups.OperationNames), id, params); err != nil {
return err
}
return cm.svc.AddParentGroup(ctx, session, id, parentID)
}
func (cm *calloutMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
group, err := cm.repo.RetrieveByID(ctx, id)
if err != nil {
return errors.Wrap(svcerr.ErrViewEntity, err)
}
params := map[string]any{
"parent_id": group.Parent,
}
if err := cm.callOut(ctx, session, groups.OpRemoveParentGroup.String(groups.OperationNames), id, params); err != nil {
return err
}
return cm.svc.RemoveParentGroup(ctx, session, id)
}
func (cm *calloutMiddleware) AddChildrenGroups(ctx context.Context, session authn.Session, id string, childrenGroupIDs []string) error {
params := map[string]any{
"children_group_ids": childrenGroupIDs,
}
if err := cm.callOut(ctx, session, groups.OpAddChildrenGroups.String(groups.OperationNames), id, params); err != nil {
return err
}
return cm.svc.AddChildrenGroups(ctx, session, id, childrenGroupIDs)
}
func (cm *calloutMiddleware) RemoveChildrenGroups(ctx context.Context, session authn.Session, id string, childrenGroupIDs []string) error {
params := map[string]any{
"children_group_ids": childrenGroupIDs,
}
if err := cm.callOut(ctx, session, groups.OpRemoveChildrenGroups.String(groups.OperationNames), id, params); err != nil {
return err
}
return cm.svc.RemoveChildrenGroups(ctx, session, id, childrenGroupIDs)
}
func (cm *calloutMiddleware) RemoveAllChildrenGroups(ctx context.Context, session authn.Session, id string) error {
if err := cm.callOut(ctx, session, groups.OpRemoveAllChildrenGroups.String(groups.OperationNames), id, nil); err != nil {
return err
}
return cm.svc.RemoveAllChildrenGroups(ctx, session, id)
}
func (cm *calloutMiddleware) ListChildrenGroups(ctx context.Context, session authn.Session, id string, startLevel, endLevel int64, pm groups.PageMeta) (groups.Page, error) {
params := map[string]any{
"start_level": startLevel,
"end_level": endLevel,
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, groups.OpListChildrenGroups.String(groups.OperationNames), id, params); err != nil {
return groups.Page{}, err
}
return cm.svc.ListChildrenGroups(ctx, session, id, startLevel, endLevel, pm)
}
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.GroupType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := cm.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
@@ -8,7 +8,6 @@ import (
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
@@ -21,12 +20,11 @@ type RoleManagerAuthorizationMiddleware struct {
entityType string
svc roles.RoleManager
authz smqauthz.Authorization
callout callout.Callout
opp svcutil.OperationPerm
}
// NewAuthorization adds authorization for role related methods to the core service.
func NewAuthorization(entityType string, svc roles.RoleManager, authz smqauthz.Authorization, opPerm map[svcutil.Operation]svcutil.Permission, callout callout.Callout) (RoleManagerAuthorizationMiddleware, error) {
func NewAuthorization(entityType string, svc roles.RoleManager, authz smqauthz.Authorization, opPerm map[svcutil.Operation]svcutil.Permission) (RoleManagerAuthorizationMiddleware, error) {
opp := roles.NewOperationPerm()
if err := opp.AddOperationPermissionMap(opPerm); err != nil {
return RoleManagerAuthorizationMiddleware{}, err
@@ -40,7 +38,6 @@ func NewAuthorization(entityType string, svc roles.RoleManager, authz smqauthz.A
svc: svc,
authz: authz,
opp: opp,
callout: callout,
}
if err := ram.validate(); err != nil {
return RoleManagerAuthorizationMiddleware{}, err
@@ -69,15 +66,6 @@ func (ram RoleManagerAuthorizationMiddleware) AddRole(ctx context.Context, sessi
if err := ram.validateMembers(ctx, session, optionalMembers); err != nil {
return roles.RoleProvision{}, err
}
params := map[string]any{
"role_name": roleName,
"optional_actions": optionalActions,
"optional_members": optionalMembers,
"count": 1,
}
if err := ram.callOut(ctx, session, roles.OpAddRole.String(roles.OperationNames), entityID, params); err != nil {
return roles.RoleProvision{}, err
}
return ram.svc.AddRole(ctx, session, entityID, roleName, optionalActions, optionalMembers)
}
@@ -92,12 +80,6 @@ func (ram RoleManagerAuthorizationMiddleware) RemoveRole(ctx context.Context, se
}); err != nil {
return err
}
params := map[string]any{
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRemoveRole.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RemoveRole(ctx, session, entityID, roleID)
}
@@ -112,13 +94,6 @@ func (ram RoleManagerAuthorizationMiddleware) UpdateRoleName(ctx context.Context
}); err != nil {
return roles.Role{}, err
}
params := map[string]any{
"role_id": roleID,
"new_role_name": newRoleName,
}
if err := ram.callOut(ctx, session, roles.OpUpdateRoleName.String(roles.OperationNames), entityID, params); err != nil {
return roles.Role{}, err
}
return ram.svc.UpdateRoleName(ctx, session, entityID, roleID, newRoleName)
}
@@ -133,12 +108,6 @@ func (ram RoleManagerAuthorizationMiddleware) RetrieveRole(ctx context.Context,
}); err != nil {
return roles.Role{}, err
}
params := map[string]any{
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRetrieveRole.String(roles.OperationNames), entityID, params); err != nil {
return roles.Role{}, err
}
return ram.svc.RetrieveRole(ctx, session, entityID, roleID)
}
@@ -153,20 +122,10 @@ func (ram RoleManagerAuthorizationMiddleware) RetrieveAllRoles(ctx context.Conte
}); err != nil {
return roles.RolePage{}, err
}
params := map[string]any{
"limit": limit,
"offset": offset,
}
if err := ram.callOut(ctx, session, roles.OpRetrieveAllRoles.String(roles.OperationNames), entityID, params); err != nil {
return roles.RolePage{}, err
}
return ram.svc.RetrieveAllRoles(ctx, session, entityID, limit, offset)
}
func (ram RoleManagerAuthorizationMiddleware) ListAvailableActions(ctx context.Context, session authn.Session) ([]string, error) {
if err := ram.callOut(ctx, session, roles.OpListAvailableActions.String(roles.OperationNames), "", nil); err != nil {
return []string{}, err
}
return ram.svc.ListAvailableActions(ctx, session)
}
@@ -182,14 +141,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleAddActions(ctx context.Context
return []string{}, err
}
params := map[string]any{
"role_id": roleID,
"actions": actions,
}
if err := ram.callOut(ctx, session, roles.OpRoleAddActions.String(roles.OperationNames), entityID, params); err != nil {
return []string{}, err
}
return ram.svc.RoleAddActions(ctx, session, entityID, roleID, actions)
}
@@ -205,13 +156,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleListActions(ctx context.Contex
return []string{}, err
}
params := map[string]any{
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRoleListActions.String(roles.OperationNames), entityID, params); err != nil {
return []string{}, err
}
return ram.svc.RoleListActions(ctx, session, entityID, roleID)
}
@@ -226,13 +170,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleCheckActionsExists(ctx context
}); err != nil {
return false, err
}
params := map[string]any{
"role_id": roleID,
"actions": actions,
}
if err := ram.callOut(ctx, session, roles.OpRoleCheckActionsExists.String(roles.OperationNames), entityID, params); err != nil {
return false, err
}
return ram.svc.RoleCheckActionsExists(ctx, session, entityID, roleID, actions)
}
@@ -247,13 +184,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveActions(ctx context.Cont
}); err != nil {
return err
}
params := map[string]any{
"role_id": roleID,
"actions": actions,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveActions.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RoleRemoveActions(ctx, session, entityID, roleID, actions)
}
@@ -269,12 +199,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveAllActions(ctx context.C
}); err != nil {
return err
}
params := map[string]any{
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveAllActions.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RoleRemoveAllActions(ctx, session, entityID, roleID)
}
@@ -293,13 +217,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleAddMembers(ctx context.Context
if err := ram.validateMembers(ctx, session, members); err != nil {
return []string{}, err
}
params := map[string]any{
"role_id": roleID,
"members": members,
}
if err := ram.callOut(ctx, session, roles.OpRoleAddMembers.String(roles.OperationNames), entityID, params); err != nil {
return []string{}, err
}
return ram.svc.RoleAddMembers(ctx, session, entityID, roleID, members)
}
@@ -314,14 +231,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleListMembers(ctx context.Contex
}); err != nil {
return roles.MembersPage{}, err
}
params := map[string]any{
"role_id": roleID,
"limit": limit,
"offset": offset,
}
if err := ram.callOut(ctx, session, roles.OpRoleListMembers.String(roles.OperationNames), entityID, params); err != nil {
return roles.MembersPage{}, err
}
return ram.svc.RoleListMembers(ctx, session, entityID, roleID, limit, offset)
}
@@ -336,13 +245,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleCheckMembersExists(ctx context
}); err != nil {
return false, err
}
params := map[string]any{
"role_id": roleID,
"members": members,
}
if err := ram.callOut(ctx, session, roles.OpRoleCheckMembersExists.String(roles.OperationNames), entityID, params); err != nil {
return false, err
}
return ram.svc.RoleCheckMembersExists(ctx, session, entityID, roleID, members)
}
@@ -357,12 +259,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveAllMembers(ctx context.C
}); err != nil {
return err
}
params := map[string]any{
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveAllMembers.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RoleRemoveAllMembers(ctx, session, entityID, roleID)
}
@@ -377,12 +273,6 @@ func (ram RoleManagerAuthorizationMiddleware) ListEntityMembers(ctx context.Cont
}); err != nil {
return roles.MembersRolePage{}, err
}
params := map[string]any{
"page_query": pageQuery,
}
if err := ram.callOut(ctx, session, roles.OpRoleListMembers.String(roles.OperationNames), entityID, params); err != nil {
return roles.MembersRolePage{}, err
}
return ram.svc.ListEntityMembers(ctx, session, entityID, pageQuery)
}
@@ -397,12 +287,6 @@ func (ram RoleManagerAuthorizationMiddleware) RemoveEntityMembers(ctx context.Co
}); err != nil {
return err
}
params := map[string]any{
"members": members,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveAllMembers.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RemoveEntityMembers(ctx, session, entityID, members)
}
@@ -417,13 +301,6 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveMembers(ctx context.Cont
}); err != nil {
return err
}
params := map[string]any{
"role_id": roleID,
"members": members,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveMembers.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RoleRemoveMembers(ctx, session, entityID, roleID, members)
}
@@ -479,23 +356,3 @@ func (ram RoleManagerAuthorizationMiddleware) validateMembers(ctx context.Contex
return nil
}
}
func (ram RoleManagerAuthorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: ram.entityType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
},
Payload: pld,
}
if err := ram.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
+245
View File
@@ -0,0 +1,245 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"time"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
)
var _ roles.RoleManager = (*RoleManagerCalloutMiddleware)(nil)
type RoleManagerCalloutMiddleware struct {
entityType string
svc roles.RoleManager
callout callout.Callout
}
func NewCallout(entityType string, svc roles.RoleManager, callout callout.Callout) (RoleManagerCalloutMiddleware, error) {
return RoleManagerCalloutMiddleware{
svc: svc,
callout: callout,
entityType: entityType,
}, nil
}
func (rcm *RoleManagerCalloutMiddleware) AddRole(ctx context.Context, session authn.Session, entityID, roleName string, optionalActions []string, optionalMembers []string) (roles.RoleProvision, error) {
params := map[string]any{
"role_name": roleName,
"optional_actions": optionalActions,
"optional_members": optionalMembers,
"count": 1,
}
if err := rcm.callOut(ctx, session, roles.OpAddRole.String(roles.OperationNames), entityID, params); err != nil {
return roles.RoleProvision{}, err
}
return rcm.svc.AddRole(ctx, session, entityID, roleName, optionalActions, optionalMembers)
}
func (rcm *RoleManagerCalloutMiddleware) RemoveRole(ctx context.Context, session authn.Session, entityID, roleID string) error {
params := map[string]any{
"role_id": roleID,
}
if err := rcm.callOut(ctx, session, roles.OpRemoveRole.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return rcm.svc.RemoveRole(ctx, session, entityID, roleID)
}
func (rcm *RoleManagerCalloutMiddleware) UpdateRoleName(ctx context.Context, session authn.Session, entityID, roleID, newRoleName string) (roles.Role, error) {
params := map[string]any{
"role_id": roleID,
"new_role_name": newRoleName,
}
if err := rcm.callOut(ctx, session, roles.OpUpdateRoleName.String(roles.OperationNames), entityID, params); err != nil {
return roles.Role{}, err
}
return rcm.svc.UpdateRoleName(ctx, session, entityID, roleID, newRoleName)
}
func (rcm *RoleManagerCalloutMiddleware) RetrieveRole(ctx context.Context, session authn.Session, entityID, roleID string) (roles.Role, error) {
params := map[string]any{
"role_id": roleID,
}
if err := rcm.callOut(ctx, session, roles.OpRetrieveRole.String(roles.OperationNames), entityID, params); err != nil {
return roles.Role{}, err
}
return rcm.svc.RetrieveRole(ctx, session, entityID, roleID)
}
func (rcm *RoleManagerCalloutMiddleware) RetrieveAllRoles(ctx context.Context, session authn.Session, entityID string, limit, offset uint64) (roles.RolePage, error) {
params := map[string]any{
"limit": limit,
"offset": offset,
}
if err := rcm.callOut(ctx, session, roles.OpRetrieveAllRoles.String(roles.OperationNames), entityID, params); err != nil {
return roles.RolePage{}, err
}
return rcm.svc.RetrieveAllRoles(ctx, session, entityID, limit, offset)
}
func (rcm *RoleManagerCalloutMiddleware) ListAvailableActions(ctx context.Context, session authn.Session) ([]string, error) {
if err := rcm.callOut(ctx, session, roles.OpListAvailableActions.String(roles.OperationNames), "", nil); err != nil {
return []string{}, err
}
return rcm.svc.ListAvailableActions(ctx, session)
}
func (rcm *RoleManagerCalloutMiddleware) RoleAddActions(ctx context.Context, session authn.Session, entityID, roleID string, actions []string) ([]string, error) {
params := map[string]any{
"role_id": roleID,
"actions": actions,
}
if err := rcm.callOut(ctx, session, roles.OpRoleAddActions.String(roles.OperationNames), entityID, params); err != nil {
return []string{}, err
}
return rcm.svc.RoleAddActions(ctx, session, entityID, roleID, actions)
}
func (rcm *RoleManagerCalloutMiddleware) RoleListActions(ctx context.Context, session authn.Session, entityID, roleID string) ([]string, error) {
params := map[string]any{
"role_id": roleID,
}
if err := rcm.callOut(ctx, session, roles.OpRoleListActions.String(roles.OperationNames), entityID, params); err != nil {
return []string{}, err
}
return rcm.svc.RoleListActions(ctx, session, entityID, roleID)
}
func (rcm *RoleManagerCalloutMiddleware) RoleCheckActionsExists(ctx context.Context, session authn.Session, entityID, roleID string, actions []string) (bool, error) {
params := map[string]any{
"role_id": roleID,
"actions": actions,
}
if err := rcm.callOut(ctx, session, roles.OpRoleCheckActionsExists.String(roles.OperationNames), entityID, params); err != nil {
return false, err
}
return rcm.svc.RoleCheckActionsExists(ctx, session, entityID, roleID, actions)
}
func (rcm *RoleManagerCalloutMiddleware) RoleRemoveActions(ctx context.Context, session authn.Session, entityID, roleID string, actions []string) error {
params := map[string]any{
"role_id": roleID,
"actions": actions,
}
if err := rcm.callOut(ctx, session, roles.OpRoleRemoveActions.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return rcm.svc.RoleRemoveActions(ctx, session, entityID, roleID, actions)
}
func (rcm *RoleManagerCalloutMiddleware) RoleRemoveAllActions(ctx context.Context, session authn.Session, entityID, roleID string) error {
params := map[string]any{
"role_id": roleID,
}
if err := rcm.callOut(ctx, session, roles.OpRoleRemoveAllActions.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return rcm.svc.RoleRemoveAllActions(ctx, session, entityID, roleID)
}
func (rcm *RoleManagerCalloutMiddleware) RoleAddMembers(ctx context.Context, session authn.Session, entityID, roleID string, members []string) ([]string, error) {
params := map[string]any{
"role_id": roleID,
"members": members,
}
if err := rcm.callOut(ctx, session, roles.OpRoleAddMembers.String(roles.OperationNames), entityID, params); err != nil {
return []string{}, err
}
return rcm.svc.RoleAddMembers(ctx, session, entityID, roleID, members)
}
func (rcm *RoleManagerCalloutMiddleware) RoleListMembers(ctx context.Context, session authn.Session, entityID, roleID string, limit, offset uint64) (roles.MembersPage, error) {
params := map[string]any{
"role_id": roleID,
"limit": limit,
"offset": offset,
}
if err := rcm.callOut(ctx, session, roles.OpRoleListMembers.String(roles.OperationNames), entityID, params); err != nil {
return roles.MembersPage{}, err
}
return rcm.svc.RoleListMembers(ctx, session, entityID, roleID, limit, offset)
}
func (rcm *RoleManagerCalloutMiddleware) RoleCheckMembersExists(ctx context.Context, session authn.Session, entityID, roleID string, members []string) (bool, error) {
params := map[string]any{
"role_id": roleID,
"members": members,
}
if err := rcm.callOut(ctx, session, roles.OpRoleCheckMembersExists.String(roles.OperationNames), entityID, params); err != nil {
return false, err
}
return rcm.svc.RoleCheckMembersExists(ctx, session, entityID, roleID, members)
}
func (rcm *RoleManagerCalloutMiddleware) RoleRemoveAllMembers(ctx context.Context, session authn.Session, entityID, roleID string) error {
params := map[string]any{
"role_id": roleID,
}
if err := rcm.callOut(ctx, session, roles.OpRoleRemoveAllMembers.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return rcm.svc.RoleRemoveAllMembers(ctx, session, entityID, roleID)
}
func (rcm *RoleManagerCalloutMiddleware) ListEntityMembers(ctx context.Context, session authn.Session, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) {
params := map[string]any{
"page_query": pageQuery,
}
if err := rcm.callOut(ctx, session, roles.OpRoleListMembers.String(roles.OperationNames), entityID, params); err != nil {
return roles.MembersRolePage{}, err
}
return rcm.svc.ListEntityMembers(ctx, session, entityID, pageQuery)
}
func (rcm *RoleManagerCalloutMiddleware) RemoveEntityMembers(ctx context.Context, session authn.Session, entityID string, members []string) error {
params := map[string]any{
"members": members,
}
if err := rcm.callOut(ctx, session, roles.OpRoleRemoveAllMembers.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return rcm.svc.RemoveEntityMembers(ctx, session, entityID, members)
}
func (rcm *RoleManagerCalloutMiddleware) RoleRemoveMembers(ctx context.Context, session authn.Session, entityID, roleID string, members []string) error {
params := map[string]any{
"role_id": roleID,
"members": members,
}
if err := rcm.callOut(ctx, session, roles.OpRoleRemoveMembers.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return rcm.svc.RoleRemoveMembers(ctx, session, entityID, roleID, members)
}
func (rcm *RoleManagerCalloutMiddleware) RemoveMemberFromAllRoles(ctx context.Context, session authn.Session, memberID string) error {
return rcm.svc.RemoveMemberFromAllRoles(ctx, session, memberID)
}
func (rcm *RoleManagerCalloutMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: rcm.entityType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := rcm.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}