NOISSUE - Use structured requests in callouts (#3191)

Signed-off-by: dusan <borovcanindusan1@gmail.com>
This commit is contained in:
Dušan Borovčanin
2025-10-17 13:11:20 +02:00
committed by GitHub
parent 42b2a8c166
commit c230a24b7d
7 changed files with 366 additions and 383 deletions
+57 -44
View File
@@ -6,7 +6,6 @@ package middleware
import (
"context"
"fmt"
"maps"
"time"
"github.com/absmach/supermq/auth"
@@ -130,11 +129,13 @@ func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session a
}
}
}
params := map[string]any{
"entities": chs,
"count": len(chs),
}
if err := am.callOut(ctx, session, channels.OpCreateChannel.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpCreateChannel.String(channels.OperationNames), "", params); err != nil {
return []channels.Channel{}, []roles.RoleProvision{}, err
}
@@ -164,12 +165,11 @@ func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session auth
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errView)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, channels.OpViewChannel.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpViewChannel.String(channels.OperationNames), id, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.ViewChannel(ctx, session, id, withRoles)
}
@@ -190,12 +190,15 @@ func (am *authorizationMiddleware) ListChannels(ctx context.Context, session aut
if err := am.checkSuperAdmin(ctx, session); err == nil {
session.SuperAdmin = true
}
params := map[string]any{
"pagemeta": pm,
}
if err := am.callOut(ctx, session, channels.OpListChannels.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpListChannels.String(channels.OperationNames), "", params); err != nil {
return channels.ChannelsPage{}, err
}
return am.svc.ListChannels(ctx, session, pm)
}
@@ -215,13 +218,16 @@ func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session
if err := am.checkSuperAdmin(ctx, session); err != nil {
return channels.ChannelsPage{}, errors.Wrap(err, errList)
}
params := map[string]any{
"user_id": userID,
"pagemeta": pm,
}
if err := am.callOut(ctx, session, channels.OpListUserChannels.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpListUserChannels.String(channels.OperationNames), "", params); err != nil {
return channels.ChannelsPage{}, err
}
return am.svc.ListUserChannels(ctx, session, userID, pm)
}
@@ -248,12 +254,11 @@ func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session au
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errUpdate)
}
params := map[string]any{
"entity_id": channel.ID,
}
if err := am.callOut(ctx, session, channels.OpUpdateChannel.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpUpdateChannel.String(channels.OperationNames), channel.ID, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.UpdateChannel(ctx, session, channel)
}
@@ -280,12 +285,11 @@ func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, sessio
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errUpdateTags)
}
params := map[string]any{
"entity_id": channel.ID,
}
if err := am.callOut(ctx, session, channels.OpUpdateChannelTags.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpUpdateChannelTags.String(channels.OperationNames), channel.ID, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.UpdateChannelTags(ctx, session, channel)
}
@@ -312,12 +316,11 @@ func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session au
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errEnable)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, channels.OpEnableChannel.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpEnableChannel.String(channels.OperationNames), id, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.EnableChannel(ctx, session, id)
}
@@ -344,12 +347,11 @@ func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session a
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errDisable)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, channels.OpDisableChannel.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpDisableChannel.String(channels.OperationNames), id, nil); err != nil {
return channels.Channel{}, err
}
return am.svc.DisableChannel(ctx, session, id)
}
@@ -375,10 +377,8 @@ func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session au
}); err != nil {
return errors.Wrap(err, errDelete)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, channels.OpDeleteChannel.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpDeleteChannel.String(channels.OperationNames), id, nil); err != nil {
return err
}
@@ -435,14 +435,17 @@ func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Se
return errors.Wrap(err, errClientConnectChannels)
}
}
params := map[string]any{
"channel_ids": chIDs,
"client_ids": thIDs,
"connection_types": connTypes,
}
if err := am.callOut(ctx, session, channels.OpConnectClient.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpConnectClient.String(channels.OperationNames), "", params); err != nil {
return err
}
return am.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
}
@@ -497,14 +500,17 @@ func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn
return errors.Wrap(err, errClientDisConnectChannels)
}
}
params := map[string]any{
"channel_ids": chIDs,
"client_ids": thIDs,
"connection_types": connTypes,
}
if err := am.callOut(ctx, session, channels.OpDisconnectClient.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpDisconnectClient.String(channels.OperationNames), "", params); err != nil {
return err
}
return am.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
}
@@ -541,13 +547,15 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
}); err != nil {
return errors.Wrap(err, errGroupSetChildChannels)
}
params := map[string]any{
"entity_id": id,
"parent_group_id": parentGroupID,
}
if err := am.callOut(ctx, session, channels.OpSetParentGroup.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpSetParentGroup.String(channels.OperationNames), id, params); err != nil {
return err
}
return am.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
@@ -589,13 +597,15 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
}); err != nil {
return errors.Wrap(err, errGroupRemoveChildChannels)
}
params := map[string]any{
"entity_id": id,
"parent_group_id": ch.ParentGroup,
}
if err := am.callOut(ctx, session, channels.OpRemoveParentGroup.String(channels.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, channels.OpRemoveParentGroup.String(channels.OperationNames), id, params); err != nil {
return err
}
return am.svc.RemoveParentGroup(ctx, session, id)
}
return nil
@@ -647,18 +657,21 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session
return nil
}
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op string, params map[string]any) error {
pl := map[string]any{
"entity_type": policies.ChannelType,
"subject_type": policies.UserType,
"subject_id": session.UserID,
"domain": session.DomainID,
"time": time.Now().UTC(),
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.ChannelType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
maps.Copy(params, pl)
if err := am.callout.Callout(ctx, op, params); err != nil {
if err := am.callout.Callout(ctx, req); err != nil {
return err
}
+30 -53
View File
@@ -5,7 +5,6 @@ package middleware
import (
"context"
"maps"
"time"
"github.com/absmach/supermq/auth"
@@ -116,7 +115,7 @@ func (am *authorizationMiddleware) CreateClients(ctx context.Context, session au
"count": len(client),
}
if err := am.callOut(ctx, session, clients.OpCreateClient.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpCreateClient.String(clients.OperationNames), "", params); err != nil {
return []clients.Client{}, []roles.RoleProvision{}, err
}
@@ -147,11 +146,7 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi
return clients.Client{}, errors.Wrap(err, errView)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, clients.OpViewClient.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpViewClient.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
@@ -179,7 +174,8 @@ func (am *authorizationMiddleware) ListClients(ctx context.Context, session auth
params := map[string]any{
"pagemeta": pm,
}
if err := am.callOut(ctx, session, clients.OpListClients.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpListClients.String(clients.OperationNames), "", params); err != nil {
return clients.ClientsPage{}, err
}
@@ -203,11 +199,13 @@ func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session
if err := am.checkSuperAdmin(ctx, session); err != nil {
return clients.ClientsPage{}, err
}
params := map[string]any{
"user_id": userID,
"pagemeta": pm,
}
if err := am.callOut(ctx, session, clients.OpListUserClients.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpListUserClients.String(clients.OperationNames), "", params); err != nil {
return clients.ClientsPage{}, err
}
@@ -238,11 +236,7 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses
return clients.Client{}, errors.Wrap(err, errUpdate)
}
params := map[string]any{
"entity_id": client.ID,
}
if err := am.callOut(ctx, session, clients.OpUpdateClient.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpUpdateClient.String(clients.OperationNames), client.ID, nil); err != nil {
return clients.Client{}, err
}
@@ -273,11 +267,7 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn
return clients.Client{}, errors.Wrap(err, errUpdateTags)
}
params := map[string]any{
"entity_id": client.ID,
}
if err := am.callOut(ctx, session, clients.OpUpdateClientTags.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpUpdateClientTags.String(clients.OperationNames), client.ID, nil); err != nil {
return clients.Client{}, err
}
@@ -308,13 +298,10 @@ func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session aut
return clients.Client{}, errors.Wrap(err, errUpdateSecret)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, clients.OpUpdateClientSecret.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpUpdateClientSecret.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
return am.svc.UpdateSecret(ctx, session, id, key)
}
@@ -342,11 +329,7 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses
return clients.Client{}, errors.Wrap(err, errEnable)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, clients.OpEnableClient.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpEnableClient.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
@@ -377,11 +360,7 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se
return clients.Client{}, errors.Wrap(err, errDisable)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, clients.OpDisableClient.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpDisableClient.String(clients.OperationNames), id, nil); err != nil {
return clients.Client{}, err
}
@@ -411,11 +390,7 @@ func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Ses
return errors.Wrap(err, errDelete)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, clients.OpDeleteClient.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpDeleteClient.String(clients.OperationNames), id, nil); err != nil {
return err
}
@@ -457,13 +432,13 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
}
params := map[string]any{
"entity_id": id,
"parent_id": parentGroupID,
}
if err := am.callOut(ctx, session, clients.OpSetParentGroup.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpSetParentGroup.String(clients.OperationNames), id, params); err != nil {
return err
}
return am.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
@@ -508,11 +483,10 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
}
params := map[string]any{
"entity_id": id,
"parent_id": th.ParentGroup,
}
if err := am.callOut(ctx, session, clients.OpRemoveParentGroup.String(clients.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, clients.OpRemoveParentGroup.String(clients.OperationNames), id, params); err != nil {
return err
}
@@ -567,18 +541,21 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session
return nil
}
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op string, params map[string]any) error {
pl := map[string]any{
"entity_type": policies.ClientType,
"subject_type": policies.UserType,
"subject_id": session.UserID,
"domain": session.DomainID,
"time": time.Now().UTC(),
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.ClientType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
maps.Copy(params, pl)
if err := am.callout.Callout(ctx, op, params); err != nil {
if err := am.callout.Callout(ctx, req); err != nil {
return err
}
+54 -45
View File
@@ -5,7 +5,6 @@ package middleware
import (
"context"
"maps"
"time"
"github.com/absmach/supermq/auth"
@@ -59,12 +58,10 @@ func NewAuthorization(entityType string, svc domains.Service, authz smqauthz.Aut
}
func (am *authorizationMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (domains.Domain, []roles.RoleProvision, error) {
params := map[string]any{
"domain": d.ID,
}
if err := am.callOut(ctx, session, domains.OpCreateDomain.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpCreateDomain.String(domains.OperationNames), d.ID, nil); err != nil {
return domains.Domain{}, nil, err
}
return am.svc.CreateDomain(ctx, session, d)
}
@@ -83,13 +80,15 @@ func (am *authorizationMiddleware) RetrieveDomain(ctx context.Context, session a
}); err != nil {
return domains.Domain{}, err
}
params := map[string]any{
"domain": id,
"with_roles": withRoles,
}
if err := am.callOut(ctx, session, domains.OpRetrieveDomain.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpRetrieveDomain.String(domains.OperationNames), id, params); err != nil {
return domains.Domain{}, err
}
return am.svc.RetrieveDomain(ctx, session, id, withRoles)
}
@@ -103,13 +102,15 @@ func (am *authorizationMiddleware) UpdateDomain(ctx context.Context, session aut
}); err != nil {
return domains.Domain{}, err
}
params := map[string]any{
"domain": id,
"domain_req": d,
}
if err := am.callOut(ctx, session, domains.OpUpdateDomain.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpUpdateDomain.String(domains.OperationNames), id, params); err != nil {
return domains.Domain{}, err
}
return am.svc.UpdateDomain(ctx, session, id, d)
}
@@ -123,12 +124,11 @@ func (am *authorizationMiddleware) EnableDomain(ctx context.Context, session aut
}); err != nil {
return domains.Domain{}, err
}
params := map[string]any{
"domain": id,
}
if err := am.callOut(ctx, session, domains.OpEnableDomain.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpEnableDomain.String(domains.OperationNames), id, nil); err != nil {
return domains.Domain{}, err
}
return am.svc.EnableDomain(ctx, session, id)
}
@@ -142,12 +142,11 @@ func (am *authorizationMiddleware) DisableDomain(ctx context.Context, session au
}); err != nil {
return domains.Domain{}, err
}
params := map[string]any{
"domain": id,
}
if err := am.callOut(ctx, session, domains.OpDisableDomain.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpDisableDomain.String(domains.OperationNames), id, nil); err != nil {
return domains.Domain{}, err
}
return am.svc.DisableDomain(ctx, session, id)
}
@@ -163,12 +162,11 @@ func (am *authorizationMiddleware) FreezeDomain(ctx context.Context, session aut
}); err != nil {
return domains.Domain{}, err
}
params := map[string]any{
"domain": id,
}
if err := am.callOut(ctx, session, domains.OpFreezeDomain.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpFreezeDomain.String(domains.OperationNames), id, nil); err != nil {
return domains.Domain{}, err
}
return am.svc.FreezeDomain(ctx, session, id)
}
@@ -176,12 +174,15 @@ func (am *authorizationMiddleware) ListDomains(ctx context.Context, session auth
if err := am.checkSuperAdmin(ctx, session); err == nil {
session.SuperAdmin = true
}
params := map[string]any{
"page": page,
}
if err := am.callOut(ctx, session, domains.OpListDomains.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpListDomains.String(domains.OperationNames), "", params); err != nil {
return domains.DomainsPage{}, err
}
return am.svc.ListDomains(ctx, session, page)
}
@@ -195,11 +196,14 @@ func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session a
if err := am.checkAdmin(ctx, session); err != nil {
return err
}
params := map[string]any{
"invitation": invitation,
"domain": invitation.DomainID,
}
if err := am.callOut(ctx, session, domains.OpSendInvitation.String(domains.OperationNames), params); err != nil {
// While entity here is technically an invitation, Domain is used as
// the entity in callout since the invitation refers to the domain.
if err := am.callOut(ctx, session, domains.OpSendInvitation.String(domains.OperationNames), invitation.DomainID, params); err != nil {
return err
}
@@ -210,7 +214,8 @@ func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session
params := map[string]any{
"page": page,
}
if err := am.callOut(ctx, session, domains.OpListInvitations.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpListInvitations.String(domains.OperationNames), "", params); err != nil {
return domains.InvitationPage{}, err
}
@@ -225,7 +230,8 @@ func (am *authorizationMiddleware) ListDomainInvitations(ctx context.Context, se
params := map[string]any{
"page": page,
}
if err := am.callOut(ctx, session, domains.OpListDomainInvitations.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpListDomainInvitations.String(domains.OperationNames), page.DomainID, params); err != nil {
return domains.InvitationPage{}, err
}
@@ -233,22 +239,22 @@ func (am *authorizationMiddleware) ListDomainInvitations(ctx context.Context, se
}
func (am *authorizationMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (inv domains.Invitation, err error) {
params := map[string]any{
"domain": domainID,
}
if err := am.callOut(ctx, session, domains.OpAcceptInvitation.String(domains.OperationNames), params); err != nil {
// Similar to sending an invitation, Domain is used as the
// entity in callout since the invitation refers to the domain.
if err := am.callOut(ctx, session, domains.OpAcceptInvitation.String(domains.OperationNames), domainID, nil); err != nil {
return domains.Invitation{}, err
}
return am.svc.AcceptInvitation(ctx, session, domainID)
}
func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
params := map[string]any{
"domain": domainID,
}
if err := am.callOut(ctx, session, domains.OpRejectInvitation.String(domains.OperationNames), params); err != nil {
// Similar to sending and accepting, Domain is used as
// the entity in callout since the invitation refers to the domain.
if err := am.callOut(ctx, session, domains.OpRejectInvitation.String(domains.OperationNames), domainID, nil); err != nil {
return err
}
return am.svc.RejectInvitation(ctx, session, domainID)
}
@@ -260,9 +266,9 @@ func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session
params := map[string]any{
"invitee_user_id": inviteeUserID,
"domain": domainID,
}
if err := am.callOut(ctx, session, domains.OpDeleteInvitation.String(domains.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, domains.OpDeleteInvitation.String(domains.OperationNames), domainID, params); err != nil {
return err
}
@@ -345,17 +351,20 @@ func (am *authorizationMiddleware) extAuthorize(ctx context.Context, subj, perm,
return nil
}
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op string, params map[string]any) error {
pl := map[string]any{
"entity_type": policies.DomainType,
"subject_type": policies.UserType,
"subject_id": session.UserID,
"time": time.Now().UTC(),
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.DomainType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
Time: time.Now().UTC(),
},
Payload: pld,
}
maps.Copy(params, pl)
if err := am.callout.Callout(ctx, op, params); err != nil {
if err := am.callout.Callout(ctx, req); err != nil {
return err
}
+57 -54
View File
@@ -6,7 +6,6 @@ package middleware
import (
"context"
"fmt"
"maps"
"time"
"github.com/absmach/supermq/auth"
@@ -132,11 +131,13 @@ func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session auth
return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errParentGroupSetChildGroup, err)
}
}
params := map[string]any{
"entities": []groups.Group{g},
"count": 1,
}
if err := am.callOut(ctx, session, groups.OpCreateGroup.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpCreateGroup.String(groups.OperationNames), "", params); err != nil {
return groups.Group{}, []roles.RoleProvision{}, err
}
@@ -167,10 +168,8 @@ func (am *authorizationMiddleware) UpdateGroup(ctx context.Context, session auth
}); err != nil {
return groups.Group{}, errors.Wrap(errUpdate, err)
}
params := map[string]any{
"entity_id": g.ID,
}
if err := am.callOut(ctx, session, groups.OpUpdateGroup.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpUpdateGroup.String(groups.OperationNames), g.ID, nil); err != nil {
return groups.Group{}, err
}
@@ -200,12 +199,11 @@ func (am *authorizationMiddleware) UpdateGroupTags(ctx context.Context, session
}); err != nil {
return groups.Group{}, errors.Wrap(errUpdateTags, err)
}
params := map[string]any{
"entity_id": group.ID,
}
if err := am.callOut(ctx, session, groups.OpUpdateGroupTags.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpUpdateGroupTags.String(groups.OperationNames), group.ID, nil); err != nil {
return groups.Group{}, err
}
return am.svc.UpdateGroupTags(ctx, session, group)
}
@@ -233,10 +231,8 @@ func (am *authorizationMiddleware) ViewGroup(ctx context.Context, session authn.
}); err != nil {
return groups.Group{}, errors.Wrap(errView, err)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, groups.OpViewGroup.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpViewGroup.String(groups.OperationNames), id, nil); err != nil {
return groups.Group{}, err
}
@@ -275,7 +271,8 @@ func (am *authorizationMiddleware) ListGroups(ctx context.Context, session authn
params := map[string]any{
"pagemeta": gm,
}
if err := am.callOut(ctx, session, groups.OpListGroups.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpListGroups.String(groups.OperationNames), "", params); err != nil {
return groups.Page{}, err
}
@@ -297,13 +294,16 @@ func (am *authorizationMiddleware) ListUserGroups(ctx context.Context, session a
}); err != nil {
return groups.Page{}, errors.Wrap(errDomainListGroups, err)
}
params := map[string]any{
"user_id": userID,
"pagemeta": pm,
}
if err := am.callOut(ctx, session, groups.OpListUserGroups.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpListUserGroups.String(groups.OperationNames), "", params); err != nil {
return groups.Page{}, err
}
return am.svc.ListUserGroups(ctx, session, userID, pm)
}
@@ -330,10 +330,8 @@ func (am *authorizationMiddleware) EnableGroup(ctx context.Context, session auth
}); err != nil {
return groups.Group{}, errors.Wrap(errEnable, err)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, groups.OpEnableGroup.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpEnableGroup.String(groups.OperationNames), id, nil); err != nil {
return groups.Group{}, err
}
@@ -363,10 +361,8 @@ func (am *authorizationMiddleware) DisableGroup(ctx context.Context, session aut
}); err != nil {
return groups.Group{}, errors.Wrap(errDisable, err)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, groups.OpDisableGroup.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpDisableGroup.String(groups.OperationNames), id, nil); err != nil {
return groups.Group{}, err
}
@@ -395,10 +391,8 @@ func (am *authorizationMiddleware) DeleteGroup(ctx context.Context, session auth
}); err != nil {
return errors.Wrap(errDelete, err)
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, groups.OpDeleteGroup.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpDeleteGroup.String(groups.OperationNames), id, nil); err != nil {
return err
}
@@ -428,13 +422,15 @@ func (am *authorizationMiddleware) RetrieveGroupHierarchy(ctx context.Context, s
}); err != nil {
return groups.HierarchyPage{}, errors.Wrap(errViewHierarchy, err)
}
params := map[string]any{
"entity_id": id,
"hierarchy_pagemeta": hm,
}
if err := am.callOut(ctx, session, groups.OpRetrieveGroupHierarchy.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpRetrieveGroupHierarchy.String(groups.OperationNames), id, params); err != nil {
return groups.HierarchyPage{}, err
}
return am.svc.RetrieveGroupHierarchy(ctx, session, id, hm)
}
@@ -471,13 +467,15 @@ func (am *authorizationMiddleware) AddParentGroup(ctx context.Context, session a
}); err != nil {
return errors.Wrap(errParentGroupSetChildGroup, err)
}
params := map[string]any{
"entity_id": id,
"parent_id": parentID,
}
if err := am.callOut(ctx, session, groups.OpAddParentGroup.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpAddParentGroup.String(groups.OperationNames), id, params); err != nil {
return err
}
return am.svc.AddParentGroup(ctx, session, id, parentID)
}
@@ -521,11 +519,12 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
return errors.Wrap(errParentGroupRemoveChildGroup, err)
}
}
params := map[string]any{
"entity_id": id,
"parent_id": group.Parent,
}
if err := am.callOut(ctx, session, groups.OpRemoveParentGroup.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpRemoveParentGroup.String(groups.OperationNames), id, params); err != nil {
return err
}
return am.svc.RemoveParentGroup(ctx, session, id)
@@ -566,11 +565,12 @@ func (am *authorizationMiddleware) AddChildrenGroups(ctx context.Context, sessio
return errors.Wrap(errChildGroupSetParentGroup, errors.Wrap(fmt.Errorf("child group id: %s", childID), err))
}
}
params := map[string]any{
"entity_id": id,
"children_group_ids": childrenGroupIDs,
}
if err := am.callOut(ctx, session, groups.OpAddChildrenGroups.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpAddChildrenGroups.String(groups.OperationNames), id, params); err != nil {
return err
}
@@ -600,11 +600,12 @@ func (am *authorizationMiddleware) RemoveChildrenGroups(ctx context.Context, ses
}); err != nil {
return errors.Wrap(errRemoveChildrenGroups, err)
}
params := map[string]any{
"entity_id": id,
"children_group_ids": childrenGroupIDs,
}
if err := am.callOut(ctx, session, groups.OpRemoveChildrenGroups.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpRemoveChildrenGroups.String(groups.OperationNames), id, params); err != nil {
return err
}
@@ -634,10 +635,8 @@ func (am *authorizationMiddleware) RemoveAllChildrenGroups(ctx context.Context,
}); err != nil {
return err
}
params := map[string]any{
"entity_id": id,
}
if err := am.callOut(ctx, session, groups.OpRemoveAllChildrenGroups.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpRemoveAllChildrenGroups.String(groups.OperationNames), id, nil); err != nil {
return err
}
@@ -667,13 +666,14 @@ func (am *authorizationMiddleware) ListChildrenGroups(ctx context.Context, sessi
}); err != nil {
return groups.Page{}, errors.Wrap(errListChildrenGroups, err)
}
params := map[string]any{
"entity_id": id,
"start_level": startLevel,
"end_level": endLevel,
"pagemeta": pm,
}
if err := am.callOut(ctx, session, groups.OpListChildrenGroups.String(groups.OperationNames), params); err != nil {
if err := am.callOut(ctx, session, groups.OpListChildrenGroups.String(groups.OperationNames), id, params); err != nil {
return groups.Page{}, err
}
@@ -724,18 +724,21 @@ func (am *authorizationMiddleware) extAuthorize(ctx context.Context, extOp svcut
return nil
}
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op string, params map[string]any) error {
pl := map[string]any{
"entity_type": policies.GroupType,
"subject_type": policies.UserType,
"subject_id": session.UserID,
"domain": session.DomainID,
"time": time.Now().UTC(),
func (am *authorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: policies.GroupType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
maps.Copy(params, pl)
if err := am.callout.Callout(ctx, op, params); err != nil {
if err := am.callout.Callout(ctx, req); err != nil {
return err
}
+82 -70
View File
@@ -21,6 +21,48 @@ import (
var errFailedToRead = errors.New("failed to read callout response body")
// Can be used in the implementation of
// callout service with structured payload.
type BaseRequest struct {
Operation string `json:"operation,omitempty"`
EntityType string `json:"entity_type,omitempty"`
EntityID string `json:"entity_id,omitempty"`
CallerID string `json:"caller_id,omitempty"`
CallerType string `json:"caller_type,omitempty"`
DomainID string `json:"domain_id,omitempty"`
Time time.Time `json:"time,omitempty"`
}
type Request struct {
BaseRequest
Payload map[string]any `json:"payload,omitempty"`
}
func (r *Request) toURL() (string, error) {
baseBytes, err := json.Marshal(r.BaseRequest)
if err != nil {
return "", err
}
res := map[string]any{}
maps.Copy(res, r.Payload)
if err := json.Unmarshal(baseBytes, &res); err != nil {
return "", err
}
ret := url.Values{}
for k, v := range res {
ret.Set(k, fmt.Sprintf("%v", v))
}
return ret.Encode(), nil
}
// Callout send a request to an external service.
type Callout interface {
Callout(ctx context.Context, req Request) error
}
type Config struct {
URLs []string `env:"URLS" envDefault:"" envSeparator:","`
Method string `env:"METHOD" envDefault:"POST"`
@@ -39,29 +81,17 @@ type callout struct {
allowedOperation map[string]struct{}
}
type CallOutReq struct {
Operation string `json:"operation"`
SubjectID string `json:"subject_id"`
SubjectType string `json:"subject_type"`
Payload map[string]any `json:"payload"`
}
// Callout send request to an external service.
type Callout interface {
Callout(ctx context.Context, perm string, pl map[string]any) error
}
// New creates a new instance of Callout.
func New(cfg Config) (Callout, error) {
httpClient, err := newCalloutClient(cfg.TLSVerification, cfg.Cert, cfg.Key, cfg.CACert, cfg.Timeout)
if err != nil {
return nil, fmt.Errorf("failied to initialize http client: %w", err)
}
if cfg.Method != http.MethodPost && cfg.Method != http.MethodGet {
return nil, fmt.Errorf("unsupported auth callout method: %s", cfg.Method)
}
httpClient, err := newCalloutClient(cfg.TLSVerification, cfg.Cert, cfg.Key, cfg.CACert, cfg.Timeout)
if err != nil {
return nil, fmt.Errorf("failed to initialize http client: %w", err)
}
allowedOperation := make(map[string]struct{})
for _, operation := range cfg.Operations {
allowedOperation[operation] = struct{}{}
@@ -75,9 +105,29 @@ func New(cfg Config) (Callout, error) {
}, nil
}
func newCalloutClient(ctls bool, certPath, keyPath, caPath string, timeout time.Duration) (*http.Client, error) {
func (c *callout) Callout(ctx context.Context, req Request) error {
if len(c.urls) == 0 {
return nil
}
if _, exists := c.allowedOperation[req.Operation]; !exists {
return nil
}
// Make requests sequentially as they appear in the URL
// slice and fail fast as soon as any request fails.
for _, url := range c.urls {
if err := c.makeRequest(ctx, url, req); err != nil {
return err
}
}
return nil
}
func newCalloutClient(skipInsecure bool, certPath, keyPath, caPath string, timeout time.Duration) (*http.Client, error) {
tlsConfig := &tls.Config{
InsecureSkipVerify: !ctls,
InsecureSkipVerify: !skipInsecure,
}
if certPath != "" || keyPath != "" {
clientTLSCert, err := server.LoadX509KeyPair(certPath, keyPath)
@@ -103,48 +153,34 @@ func newCalloutClient(ctls bool, certPath, keyPath, caPath string, timeout time.
return httpClient, nil
}
func (c *callout) makeRequest(ctx context.Context, urlStr string, params map[string]any) error {
var req *http.Request
func (c *callout) makeRequest(ctx context.Context, urlStr string, req Request) error {
var r *http.Request
var err error
switch c.method {
case http.MethodGet:
query := url.Values{}
for key, value := range params {
query.Set(key, fmt.Sprintf("%v", value))
var query string
query, err = req.toURL()
if err != nil {
return err
}
req, err = http.NewRequestWithContext(ctx, c.method, urlStr+"?"+query.Encode(), nil)
r, err = http.NewRequestWithContext(ctx, c.method, urlStr+"?"+query, nil)
case http.MethodPost:
payload := make(map[string]any)
maps.Copy(payload, params)
operation, _ := params["operation"].(string)
subjectID, _ := params["subject_id"].(string)
subjectType, _ := params["subject_type"].(string)
delete(payload, "subject_id")
delete(payload, "subject_type")
delete(payload, "operation")
calloutReq := CallOutReq{
Operation: operation,
SubjectID: subjectID,
SubjectType: subjectType,
Payload: payload,
}
data, jsonErr := json.Marshal(calloutReq)
data, jsonErr := json.Marshal(req)
if jsonErr != nil {
return jsonErr
}
req, err = http.NewRequestWithContext(ctx, c.method, urlStr, bytes.NewReader(data))
req.Header.Set("Content-Type", "application/json")
r, err = http.NewRequestWithContext(ctx, c.method, urlStr, bytes.NewReader(data))
if err == nil {
r.Header.Set("Content-Type", "application/json")
}
}
if err != nil {
return err
}
resp, err := c.httpClient.Do(req)
resp, err := c.httpClient.Do(r)
if err != nil {
return err
}
@@ -160,27 +196,3 @@ func (c *callout) makeRequest(ctx context.Context, urlStr string, params map[str
return nil
}
func (c *callout) Callout(ctx context.Context, op string, pl map[string]any) error {
if len(c.urls) == 0 {
return nil
}
// Check if the operation is in the allowed list
// Otherwise, only call webhook if the operation is in the map
if _, exists := c.allowedOperation[op]; !exists {
return nil
}
pl["operation"] = op
// We iterate through all URLs in sequence
// if any request fails, we return the error immediately
for _, url := range c.urls {
if err := c.makeRequest(ctx, url, pl); err != nil {
return err
}
}
return nil
}
+33 -46
View File
@@ -32,18 +32,22 @@ const (
filePermission = 0o644
)
var pl = map[string]any{
"entity_type": entityType,
"sender": userID,
"domain": domainID,
"time": time.Now().UTC(),
"operation": operation,
var req = callout.Request{
BaseRequest: callout.BaseRequest{
Operation: operation,
EntityType: entityType,
},
Payload: map[string]any{
"sender": userID,
"time": time.Now().UTC(),
"domain": domainID,
},
}
func TestNewCallout(t *testing.T) {
cases := []struct {
desc string
ctls bool
withTLS bool
certPath string
keyPath string
caPath string
@@ -55,7 +59,7 @@ func TestNewCallout(t *testing.T) {
}{
{
desc: "successful callout creation without TLS",
ctls: false,
withTLS: false,
timeout: time.Second,
method: http.MethodPost,
urls: []string{"http://example.com"},
@@ -63,7 +67,7 @@ func TestNewCallout(t *testing.T) {
},
{
desc: "successful callout creation with TLS",
ctls: true,
withTLS: true,
certPath: "client.crt",
keyPath: "client.key",
caPath: "ca.crt",
@@ -74,7 +78,7 @@ func TestNewCallout(t *testing.T) {
},
{
desc: "failed callout creation with invalid cert",
ctls: true,
withTLS: true,
certPath: "invalid.crt",
keyPath: "invalid.key",
caPath: "invalid.ca",
@@ -82,11 +86,11 @@ func TestNewCallout(t *testing.T) {
method: http.MethodPost,
urls: []string{"http://example.com"},
operations: []string{},
err: errors.New("failied to initialize http client: tls: failed to find any PEM data in certificate input"),
err: errors.New("failed to initialize http client: tls: failed to find any PEM data in certificate input"),
},
{
desc: "invalid method",
ctls: false,
withTLS: false,
timeout: time.Second,
method: "INVALID-METHOD",
urls: []string{"http://example.com"},
@@ -97,7 +101,8 @@ func TestNewCallout(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.desc == "successful callout creation with TLS" {
switch tc.desc {
case "successful callout creation with TLS":
generateAndWriteCertificates(t, tc.caPath, tc.certPath, tc.keyPath)
defer func() {
@@ -105,7 +110,7 @@ func TestNewCallout(t *testing.T) {
os.Remove(tc.keyPath)
os.Remove(tc.caPath)
}()
} else if tc.desc == "failed callout creation with invalid cert" {
case "failed callout creation with invalid cert":
writeFile(t, tc.certPath, []byte("invalid cert content"))
writeFile(t, tc.keyPath, []byte("invalid key content"))
writeFile(t, tc.caPath, []byte("invalid ca content"))
@@ -118,7 +123,7 @@ func TestNewCallout(t *testing.T) {
}
client, err := callout.New(callout.Config{
TLSVerification: tc.ctls,
TLSVerification: tc.withTLS,
Cert: tc.certPath,
Key: tc.keyPath,
CACert: tc.caPath,
@@ -321,7 +326,7 @@ func TestCallout_MakeRequest(t *testing.T) {
assert.NoError(t, err)
ctx := tc.contextSetup()
err = cb.Callout(ctx, operation, pl)
err = cb.Callout(ctx, req)
if tc.expectError {
assert.Error(t, err)
@@ -339,43 +344,25 @@ func TestCallout_Operations(t *testing.T) {
cases := []struct {
desc string
operations []string
payload map[string]any
request callout.Request
serverCalled bool
}{
{
desc: "matching operation is called",
operations: []string{operation},
payload: map[string]any{
"entity_type": entityType,
"sender": userID,
"domain": domainID,
"time": time.Now().UTC(),
"operation": operation,
},
desc: "matching operation is called",
operations: []string{operation},
request: req,
serverCalled: true,
},
{
desc: "non-matching operation is not called",
operations: []string{"other_operation"},
payload: map[string]any{
"entity_type": entityType,
"sender": userID,
"domain": domainID,
"time": time.Now().UTC(),
"operation": operation,
},
desc: "non-matching operation is not called",
operations: []string{"other_operation"},
request: req,
serverCalled: false,
},
{
desc: "empty operations list calls always",
operations: []string{},
payload: map[string]any{
"entity_type": entityType,
"sender": userID,
"domain": domainID,
"time": time.Now().UTC(),
"operation": operation,
},
desc: "empty operations list calls always",
operations: []string{},
request: req,
serverCalled: false,
},
}
@@ -401,7 +388,7 @@ func TestCallout_Operations(t *testing.T) {
})
assert.NoError(t, err)
err = cb.Callout(context.Background(), operation, tc.payload)
err = cb.Callout(context.Background(), tc.request)
assert.NoError(t, err)
assert.Equal(t, tc.serverCalled, serverCalled, "Server call status does not match expected")
})
@@ -421,6 +408,6 @@ func TestCallout_NoURLs(t *testing.T) {
})
assert.NoError(t, err)
err = cb.Callout(context.Background(), operation, pl)
err = cb.Callout(context.Background(), req)
assert.NoError(t, err, "No error should be returned when URL list is empty")
}
@@ -5,8 +5,6 @@ package middleware
import (
"context"
"maps"
"time"
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
@@ -72,13 +70,12 @@ func (ram RoleManagerAuthorizationMiddleware) AddRole(ctx context.Context, sessi
return roles.RoleProvision{}, err
}
params := map[string]any{
"entity_id": entityID,
"role_name": roleName,
"optional_actions": optionalActions,
"optional_members": optionalMembers,
"count": 1,
}
if err := ram.callOut(ctx, session, roles.OpAddRole.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpAddRole.String(roles.OperationNames), entityID, params); err != nil {
return roles.RoleProvision{}, err
}
return ram.svc.AddRole(ctx, session, entityID, roleName, optionalActions, optionalMembers)
@@ -96,10 +93,9 @@ func (ram RoleManagerAuthorizationMiddleware) RemoveRole(ctx context.Context, se
return err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRemoveRole.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRemoveRole.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RemoveRole(ctx, session, entityID, roleID)
@@ -117,11 +113,10 @@ func (ram RoleManagerAuthorizationMiddleware) UpdateRoleName(ctx context.Context
return roles.Role{}, err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"new_role_name": newRoleName,
}
if err := ram.callOut(ctx, session, roles.OpUpdateRoleName.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpUpdateRoleName.String(roles.OperationNames), entityID, params); err != nil {
return roles.Role{}, err
}
return ram.svc.UpdateRoleName(ctx, session, entityID, roleID, newRoleName)
@@ -139,10 +134,9 @@ func (ram RoleManagerAuthorizationMiddleware) RetrieveRole(ctx context.Context,
return roles.Role{}, err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRetrieveRole.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRetrieveRole.String(roles.OperationNames), entityID, params); err != nil {
return roles.Role{}, err
}
return ram.svc.RetrieveRole(ctx, session, entityID, roleID)
@@ -160,19 +154,17 @@ func (ram RoleManagerAuthorizationMiddleware) RetrieveAllRoles(ctx context.Conte
return roles.RolePage{}, err
}
params := map[string]any{
"entity_id": entityID,
"limit": limit,
"offset": offset,
"limit": limit,
"offset": offset,
}
if err := ram.callOut(ctx, session, roles.OpRetrieveAllRoles.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRetrieveAllRoles.String(roles.OperationNames), entityID, params); err != nil {
return roles.RolePage{}, err
}
return ram.svc.RetrieveAllRoles(ctx, session, entityID, limit, offset)
}
func (ram RoleManagerAuthorizationMiddleware) ListAvailableActions(ctx context.Context, session authn.Session) ([]string, error) {
params := map[string]any{}
if err := ram.callOut(ctx, session, roles.OpListAvailableActions.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpListAvailableActions.String(roles.OperationNames), "", nil); err != nil {
return []string{}, err
}
return ram.svc.ListAvailableActions(ctx, session)
@@ -191,11 +183,10 @@ func (ram RoleManagerAuthorizationMiddleware) RoleAddActions(ctx context.Context
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"actions": actions,
"role_id": roleID,
"actions": actions,
}
if err := ram.callOut(ctx, session, roles.OpRoleAddActions.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleAddActions.String(roles.OperationNames), entityID, params); err != nil {
return []string{}, err
}
@@ -215,10 +206,9 @@ func (ram RoleManagerAuthorizationMiddleware) RoleListActions(ctx context.Contex
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRoleListActions.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleListActions.String(roles.OperationNames), entityID, params); err != nil {
return []string{}, err
}
@@ -237,11 +227,10 @@ func (ram RoleManagerAuthorizationMiddleware) RoleCheckActionsExists(ctx context
return false, err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"actions": actions,
"role_id": roleID,
"actions": actions,
}
if err := ram.callOut(ctx, session, roles.OpRoleCheckActionsExists.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleCheckActionsExists.String(roles.OperationNames), entityID, params); err != nil {
return false, err
}
return ram.svc.RoleCheckActionsExists(ctx, session, entityID, roleID, actions)
@@ -259,11 +248,10 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveActions(ctx context.Cont
return err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"actions": actions,
"role_id": roleID,
"actions": actions,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveActions.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleRemoveActions.String(roles.OperationNames), entityID, params); err != nil {
return err
}
@@ -282,10 +270,9 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveAllActions(ctx context.C
return err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveAllActions.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleRemoveAllActions.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RoleRemoveAllActions(ctx, session, entityID, roleID)
@@ -307,11 +294,10 @@ func (ram RoleManagerAuthorizationMiddleware) RoleAddMembers(ctx context.Context
return []string{}, err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"members": members,
"role_id": roleID,
"members": members,
}
if err := ram.callOut(ctx, session, roles.OpRoleAddMembers.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleAddMembers.String(roles.OperationNames), entityID, params); err != nil {
return []string{}, err
}
return ram.svc.RoleAddMembers(ctx, session, entityID, roleID, members)
@@ -329,12 +315,11 @@ func (ram RoleManagerAuthorizationMiddleware) RoleListMembers(ctx context.Contex
return roles.MembersPage{}, err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"limit": limit,
"offset": offset,
"role_id": roleID,
"limit": limit,
"offset": offset,
}
if err := ram.callOut(ctx, session, roles.OpRoleListMembers.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleListMembers.String(roles.OperationNames), entityID, params); err != nil {
return roles.MembersPage{}, err
}
return ram.svc.RoleListMembers(ctx, session, entityID, roleID, limit, offset)
@@ -352,11 +337,10 @@ func (ram RoleManagerAuthorizationMiddleware) RoleCheckMembersExists(ctx context
return false, err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"members": members,
"role_id": roleID,
"members": members,
}
if err := ram.callOut(ctx, session, roles.OpRoleCheckMembersExists.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleCheckMembersExists.String(roles.OperationNames), entityID, params); err != nil {
return false, err
}
return ram.svc.RoleCheckMembersExists(ctx, session, entityID, roleID, members)
@@ -374,10 +358,9 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveAllMembers(ctx context.C
return err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"role_id": roleID,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveAllMembers.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleRemoveAllMembers.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RoleRemoveAllMembers(ctx, session, entityID, roleID)
@@ -395,10 +378,9 @@ func (ram RoleManagerAuthorizationMiddleware) ListEntityMembers(ctx context.Cont
return roles.MembersRolePage{}, err
}
params := map[string]any{
"entity_id": entityID,
"page_query": pageQuery,
}
if err := ram.callOut(ctx, session, roles.OpRoleListMembers.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleListMembers.String(roles.OperationNames), entityID, params); err != nil {
return roles.MembersRolePage{}, err
}
return ram.svc.ListEntityMembers(ctx, session, entityID, pageQuery)
@@ -416,10 +398,9 @@ func (ram RoleManagerAuthorizationMiddleware) RemoveEntityMembers(ctx context.Co
return err
}
params := map[string]any{
"entity_id": entityID,
"members": members,
"members": members,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveAllMembers.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleRemoveAllMembers.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RemoveEntityMembers(ctx, session, entityID, members)
@@ -437,11 +418,10 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveMembers(ctx context.Cont
return err
}
params := map[string]any{
"entity_id": entityID,
"role_id": roleID,
"members": members,
"role_id": roleID,
"members": members,
}
if err := ram.callOut(ctx, session, roles.OpRoleRemoveMembers.String(roles.OperationNames), params); err != nil {
if err := ram.callOut(ctx, session, roles.OpRoleRemoveMembers.String(roles.OperationNames), entityID, params); err != nil {
return err
}
return ram.svc.RoleRemoveMembers(ctx, session, entityID, roleID, members)
@@ -500,18 +480,20 @@ func (ram RoleManagerAuthorizationMiddleware) validateMembers(ctx context.Contex
}
}
func (ram RoleManagerAuthorizationMiddleware) callOut(ctx context.Context, session authn.Session, op string, params map[string]any) error {
pl := map[string]any{
"entity_type": ram.entityType,
"subject_type": policies.UserType,
"subject_id": session.UserID,
"domain": session.DomainID,
"time": time.Now().UTC(),
func (ram RoleManagerAuthorizationMiddleware) callOut(ctx context.Context, session authn.Session, op, entityID string, pld map[string]any) error {
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
EntityType: ram.entityType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
},
Payload: pld,
}
maps.Copy(params, pl)
if err := ram.callout.Callout(ctx, op, params); err != nil {
if err := ram.callout.Callout(ctx, req); err != nil {
return err
}