mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 07:30:25 +00:00
MG-370 - Add fine grained access control to rules engine (#402)
* update go mod file Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix rules endpoint tests Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix yaml file Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix build Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * address comments Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * remove roles from alarms Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * change approach for schema combaine Signed-off-by: Arvindh <arvindh91@gmail.com> * change approach for schema combaine Signed-off-by: Arvindh <arvindh91@gmail.com> * fix permissions for rules Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix authorization file Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix linter Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix linter Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> --------- Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> Signed-off-by: Arvindh <arvindh91@gmail.com> Co-authored-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
+1
-1
@@ -44,7 +44,7 @@ func viewRuleEndpoint(s re.Service) endpoint.Endpoint {
|
||||
if err := req.validate(); err != nil {
|
||||
return viewRuleRes{}, err
|
||||
}
|
||||
rule, err := s.ViewRule(ctx, session, req.id)
|
||||
rule, err := s.ViewRule(ctx, session, req.id, req.withRoles)
|
||||
if err != nil {
|
||||
return viewRuleRes{}, err
|
||||
}
|
||||
|
||||
@@ -334,7 +334,7 @@ func TestViewRuleEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("ViewRule", mock.Anything, tc.authnRes, tc.id).Return(tc.svcRes, tc.svcErr)
|
||||
svcCall := svc.On("ViewRule", mock.Anything, tc.authnRes, tc.id, false).Return(tc.svcRes, tc.svcErr)
|
||||
res, err := req.make()
|
||||
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
|
||||
+2
-1
@@ -33,7 +33,8 @@ func (req addRuleReq) validate() error {
|
||||
}
|
||||
|
||||
type viewRuleReq struct {
|
||||
id string
|
||||
id string
|
||||
withRoles bool
|
||||
}
|
||||
|
||||
func (req viewRuleReq) validate() error {
|
||||
|
||||
+12
-1
@@ -16,6 +16,7 @@ import (
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api"
|
||||
"github.com/go-chi/chi/v5"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
@@ -36,6 +37,8 @@ func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, l
|
||||
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
|
||||
r.Route("/{domainID}", func(r chi.Router) {
|
||||
r.Route("/rules", func(r chi.Router) {
|
||||
d := roleManagerHttp.NewDecoder("ruleID")
|
||||
|
||||
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
addRuleEndpoint(svc),
|
||||
decodeAddRuleRequest,
|
||||
@@ -50,6 +53,8 @@ func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, l
|
||||
opts...,
|
||||
), "list_rules").ServeHTTP)
|
||||
|
||||
r = roleManagerHttp.EntityAvailableActionsRouter(svc, d, r, opts)
|
||||
|
||||
r.Route("/{ruleID}", func(r chi.Router) {
|
||||
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
viewRuleEndpoint(svc),
|
||||
@@ -99,6 +104,8 @@ func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, l
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "disable_rule").ServeHTTP)
|
||||
|
||||
roleManagerHttp.EntityRoleMangerRouter(svc, d, r, opts)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -123,7 +130,11 @@ func decodeAddRuleRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
func decodeViewRuleRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
id := chi.URLParam(r, ruleIdKey)
|
||||
return viewRuleReq{id: id}, nil
|
||||
withRoles, err := apiutil.ReadBoolQuery(r, api.RolesKey, false)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
return viewRuleReq{id: id, withRoles: withRoles}, nil
|
||||
}
|
||||
|
||||
func decodeUpdateRuleRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package re
|
||||
|
||||
import "github.com/absmach/supermq/pkg/roles"
|
||||
|
||||
const BuiltInRoleAdmin roles.BuiltInRoleName = "admin"
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/absmach/supermq/pkg/events"
|
||||
"github.com/absmach/supermq/pkg/events/store"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
rmEvents "github.com/absmach/supermq/pkg/roles/rolemanager/events"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
@@ -32,6 +33,7 @@ var _ re.Service = (*eventStore)(nil)
|
||||
type eventStore struct {
|
||||
events.Publisher
|
||||
svc re.Service
|
||||
rmEvents.RoleManagerEventStore
|
||||
}
|
||||
|
||||
// NewEventStoreMiddleware returns wrapper around rules service that sends
|
||||
@@ -42,9 +44,12 @@ func NewEventStoreMiddleware(ctx context.Context, svc re.Service, url string) (r
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := rmEvents.NewRoleManagerEventStore("alarms", supermqPrefix, svc, publisher)
|
||||
|
||||
return &eventStore{
|
||||
svc: svc,
|
||||
Publisher: publisher,
|
||||
svc: svc,
|
||||
Publisher: publisher,
|
||||
RoleManagerEventStore: res,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -78,8 +83,8 @@ func (es *eventStore) ListRules(ctx context.Context, session authn.Session, pm r
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) ViewRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
|
||||
rule, err := es.svc.ViewRule(ctx, session, id)
|
||||
func (es *eventStore) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
|
||||
rule, err := es.svc.ViewRule(ctx, session, id, withRoles)
|
||||
if err != nil {
|
||||
return rule, err
|
||||
}
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/re"
|
||||
"github.com/absmach/magistrala/re/operations"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/pkg/permissions"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
rolemgr "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,36 +25,47 @@ var (
|
||||
)
|
||||
|
||||
type authorizationMiddleware struct {
|
||||
svc re.Service
|
||||
authz smqauthz.Authorization
|
||||
svc re.Service
|
||||
authz smqauthz.Authorization
|
||||
entitiesOps permissions.EntitiesOperations[permissions.Operation]
|
||||
rolemgr.RoleManagerAuthorizationMiddleware
|
||||
}
|
||||
|
||||
// AuthorizationMiddleware adds authorization to the re service.
|
||||
func AuthorizationMiddleware(svc re.Service, authz smqauthz.Authorization) (re.Service, error) {
|
||||
func AuthorizationMiddleware(svc re.Service, authz smqauthz.Authorization, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation]) (re.Service, error) {
|
||||
if err := entitiesOps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ram, err := rolemgr.NewAuthorization(operations.EntityType, svc, authz, roleOps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &authorizationMiddleware{
|
||||
svc: svc,
|
||||
authz: authz,
|
||||
svc: svc,
|
||||
authz: authz,
|
||||
entitiesOps: entitiesOps,
|
||||
RoleManagerAuthorizationMiddleware: ram,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
if err := am.authorize(ctx, re.OpAddRule, session); err != nil {
|
||||
if err := am.authorize(ctx, operations.OpAddRule, session, policies.DomainType, session.DomainID); err != nil {
|
||||
return re.Rule{}, errors.Wrap(errDomainCreateRules, err)
|
||||
}
|
||||
|
||||
return am.svc.AddRule(ctx, session, r)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ViewRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
|
||||
if err := am.authorize(ctx, re.OpViewRule, session); err != nil {
|
||||
func (am *authorizationMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
|
||||
if err := am.authorize(ctx, operations.OpViewRule, session, operations.EntityType, id); err != nil {
|
||||
return re.Rule{}, errors.Wrap(errDomainViewRules, err)
|
||||
}
|
||||
|
||||
return am.svc.ViewRule(ctx, session, id)
|
||||
return am.svc.ViewRule(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
if err := am.authorize(ctx, re.OpUpdateRule, session); err != nil {
|
||||
if err := am.authorize(ctx, operations.OpUpdateRule, session, operations.EntityType, r.ID); err != nil {
|
||||
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
|
||||
}
|
||||
|
||||
@@ -60,7 +73,7 @@ func (am *authorizationMiddleware) UpdateRule(ctx context.Context, session authn
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
if err := am.authorize(ctx, re.OpUpdateRuleTags, session); err != nil {
|
||||
if err := am.authorize(ctx, operations.OpUpdateRuleTags, session, operations.EntityType, r.ID); err != nil {
|
||||
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
|
||||
}
|
||||
|
||||
@@ -68,7 +81,7 @@ func (am *authorizationMiddleware) UpdateRuleTags(ctx context.Context, session a
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
if err := am.authorize(ctx, re.OpUpdateRuleSchedule, session); err != nil {
|
||||
if err := am.authorize(ctx, operations.OpUpdateRuleSchedule, session, operations.EntityType, r.ID); err != nil {
|
||||
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
|
||||
}
|
||||
|
||||
@@ -76,7 +89,7 @@ func (am *authorizationMiddleware) UpdateRuleSchedule(ctx context.Context, sessi
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) {
|
||||
if err := am.authorize(ctx, re.OpListRules, session); err != nil {
|
||||
if err := am.authorize(ctx, operations.OpListRules, session, policies.DomainType, session.DomainID); err != nil {
|
||||
return re.Page{}, errors.Wrap(errDomainViewRules, err)
|
||||
}
|
||||
|
||||
@@ -84,7 +97,7 @@ func (am *authorizationMiddleware) ListRules(ctx context.Context, session authn.
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error {
|
||||
if err := am.authorize(ctx, re.OpRemoveRule, session); err != nil {
|
||||
if err := am.authorize(ctx, operations.OpRemoveRule, session, operations.EntityType, id); err != nil {
|
||||
return errors.Wrap(errDomainDeleteRules, err)
|
||||
}
|
||||
|
||||
@@ -92,7 +105,7 @@ func (am *authorizationMiddleware) RemoveRule(ctx context.Context, session authn
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
|
||||
if err := am.authorize(ctx, re.OpEnableRule, session); err != nil {
|
||||
if err := am.authorize(ctx, operations.OpEnableRule, session, operations.EntityType, id); err != nil {
|
||||
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
|
||||
}
|
||||
|
||||
@@ -100,7 +113,7 @@ func (am *authorizationMiddleware) EnableRule(ctx context.Context, session authn
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
|
||||
if err := am.authorize(ctx, re.OpDisableRule, session); err != nil {
|
||||
if err := am.authorize(ctx, operations.OpDisableRule, session, operations.EntityType, id); err != nil {
|
||||
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
|
||||
}
|
||||
|
||||
@@ -119,8 +132,8 @@ func (am *authorizationMiddleware) Cancel() error {
|
||||
return am.svc.Cancel()
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions.Operation, session authn.Session) error {
|
||||
perm, err := re.GetPermission(op)
|
||||
func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions.Operation, session authn.Session, objType, obj string) error {
|
||||
perm, err := am.entitiesOps.GetPermission(operations.EntityType, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -130,19 +143,19 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions
|
||||
SubjectType: policies.UserType,
|
||||
SubjectKind: policies.UsersKind,
|
||||
Subject: session.DomainUserID,
|
||||
Object: session.DomainID,
|
||||
ObjectType: policies.DomainType,
|
||||
Permission: perm,
|
||||
Object: obj,
|
||||
ObjectType: objType,
|
||||
Permission: perm.String(),
|
||||
}
|
||||
|
||||
var pat *smqauthz.PATReq
|
||||
if session.PatID != "" {
|
||||
opName := re.OperationName(op)
|
||||
opName := am.entitiesOps.OperationName(operations.EntityType, op)
|
||||
pat = &smqauthz.PATReq{
|
||||
UserID: session.UserID,
|
||||
PatID: session.PatID,
|
||||
EntityID: session.DomainID,
|
||||
EntityType: re.EntityType,
|
||||
EntityType: operations.EntityType,
|
||||
Operation: opName,
|
||||
Domain: session.DomainID,
|
||||
}
|
||||
|
||||
+35
-18
@@ -7,26 +7,43 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
mgPolicies "github.com/absmach/magistrala/pkg/policies"
|
||||
"github.com/absmach/magistrala/re"
|
||||
"github.com/absmach/magistrala/re/operations"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/callout"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/pkg/permissions"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
|
||||
)
|
||||
|
||||
var _ re.Service = (*calloutMiddleware)(nil)
|
||||
|
||||
type calloutMiddleware struct {
|
||||
svc re.Service
|
||||
callout callout.Callout
|
||||
svc re.Service
|
||||
callout callout.Callout
|
||||
entitiesOps permissions.EntitiesOperations[permissions.Operation]
|
||||
rolemw.RoleManagerCalloutMiddleware
|
||||
}
|
||||
|
||||
const entityType = "rule"
|
||||
|
||||
func NewCallout(svc re.Service, callout callout.Callout) (re.Service, error) {
|
||||
func NewCallout(svc re.Service, callout callout.Callout, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation]) (re.Service, error) {
|
||||
call, err := rolemw.NewCallout(mgPolicies.RulesType, svc, callout, roleOps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := entitiesOps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &calloutMiddleware{
|
||||
svc: svc,
|
||||
callout: callout,
|
||||
svc: svc,
|
||||
callout: callout,
|
||||
entitiesOps: entitiesOps,
|
||||
RoleManagerCalloutMiddleware: call,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -36,23 +53,23 @@ func (cm *calloutMiddleware) AddRule(ctx context.Context, session authn.Session,
|
||||
"count": 1,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, re.OpAddRuleStr, params); err != nil {
|
||||
if err := cm.callOut(ctx, session, operations.OpAddRule, params); err != nil {
|
||||
return re.Rule{}, err
|
||||
}
|
||||
|
||||
return cm.svc.AddRule(ctx, session, r)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) ViewRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
|
||||
func (cm *calloutMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
|
||||
params := map[string]any{
|
||||
"entity_id": id,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, re.OpViewRuleStr, params); err != nil {
|
||||
if err := cm.callOut(ctx, session, operations.OpViewRule, params); err != nil {
|
||||
return re.Rule{}, err
|
||||
}
|
||||
|
||||
return cm.svc.ViewRule(ctx, session, id)
|
||||
return cm.svc.ViewRule(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
@@ -60,7 +77,7 @@ func (cm *calloutMiddleware) UpdateRule(ctx context.Context, session authn.Sessi
|
||||
"entity_id": r.ID,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, re.OpUpdateRuleStr, params); err != nil {
|
||||
if err := cm.callOut(ctx, session, operations.OpUpdateRule, params); err != nil {
|
||||
return re.Rule{}, err
|
||||
}
|
||||
|
||||
@@ -72,7 +89,7 @@ func (cm *calloutMiddleware) UpdateRuleTags(ctx context.Context, session authn.S
|
||||
"entity_id": r.ID,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, re.OpUpdateRuleTagsStr, params); err != nil {
|
||||
if err := cm.callOut(ctx, session, operations.OpUpdateRuleTags, params); err != nil {
|
||||
return re.Rule{}, err
|
||||
}
|
||||
|
||||
@@ -84,7 +101,7 @@ func (cm *calloutMiddleware) UpdateRuleSchedule(ctx context.Context, session aut
|
||||
"entity_id": r.ID,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, re.OpUpdateRuleScheduleStr, params); err != nil {
|
||||
if err := cm.callOut(ctx, session, operations.OpUpdateRuleSchedule, params); err != nil {
|
||||
return re.Rule{}, err
|
||||
}
|
||||
|
||||
@@ -96,7 +113,7 @@ func (cm *calloutMiddleware) ListRules(ctx context.Context, session authn.Sessio
|
||||
"pagemeta": pm,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, re.OpListRulesStr, params); err != nil {
|
||||
if err := cm.callOut(ctx, session, operations.OpListRules, params); err != nil {
|
||||
return re.Page{}, err
|
||||
}
|
||||
|
||||
@@ -108,7 +125,7 @@ func (cm *calloutMiddleware) RemoveRule(ctx context.Context, session authn.Sessi
|
||||
"entity_id": id,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, re.OpRemoveRuleStr, params); err != nil {
|
||||
if err := cm.callOut(ctx, session, operations.OpRemoveRule, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -120,7 +137,7 @@ func (cm *calloutMiddleware) EnableRule(ctx context.Context, session authn.Sessi
|
||||
"entity_id": id,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, re.OpEnableRuleStr, params); err != nil {
|
||||
if err := cm.callOut(ctx, session, operations.OpEnableRule, params); err != nil {
|
||||
return re.Rule{}, err
|
||||
}
|
||||
|
||||
@@ -132,7 +149,7 @@ func (cm *calloutMiddleware) DisableRule(ctx context.Context, session authn.Sess
|
||||
"entity_id": id,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, re.OpDisableRuleStr, params); err != nil {
|
||||
if err := cm.callOut(ctx, session, operations.OpDisableRule, params); err != nil {
|
||||
return re.Rule{}, err
|
||||
}
|
||||
|
||||
@@ -151,7 +168,7 @@ func (cm *calloutMiddleware) Cancel() error {
|
||||
return cm.svc.Cancel()
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op string, pld map[string]any) error {
|
||||
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op permissions.Operation, pld map[string]any) error {
|
||||
var entityID string
|
||||
if id, ok := pld["entity_id"].(string); ok {
|
||||
entityID = id
|
||||
@@ -159,7 +176,7 @@ func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session,
|
||||
|
||||
req := callout.Request{
|
||||
BaseRequest: callout.BaseRequest{
|
||||
Operation: op,
|
||||
Operation: cm.entitiesOps.OperationName(entityType, op),
|
||||
EntityType: entityType,
|
||||
EntityID: entityID,
|
||||
CallerID: session.UserID,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/absmach/magistrala/re"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
|
||||
)
|
||||
|
||||
var _ re.Service = (*loggingMiddleware)(nil)
|
||||
@@ -19,10 +20,15 @@ var _ re.Service = (*loggingMiddleware)(nil)
|
||||
type loggingMiddleware struct {
|
||||
logger *slog.Logger
|
||||
svc re.Service
|
||||
rolemw.RoleManagerLoggingMiddleware
|
||||
}
|
||||
|
||||
func LoggingMiddleware(svc re.Service, logger *slog.Logger) re.Service {
|
||||
return &loggingMiddleware{logger, svc}
|
||||
return &loggingMiddleware{
|
||||
logger: logger,
|
||||
svc: svc,
|
||||
RoleManagerLoggingMiddleware: rolemw.NewLogging("re", svc, logger),
|
||||
}
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (res re.Rule, err error) {
|
||||
@@ -42,7 +48,7 @@ func (lm *loggingMiddleware) AddRule(ctx context.Context, session authn.Session,
|
||||
return lm.svc.AddRule(ctx, session, r)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ViewRule(ctx context.Context, session authn.Session, id string) (res re.Rule, err error) {
|
||||
func (lm *loggingMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (res re.Rule, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
@@ -59,7 +65,7 @@ func (lm *loggingMiddleware) ViewRule(ctx context.Context, session authn.Session
|
||||
}
|
||||
lm.logger.Info("View rule completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.ViewRule(ctx, session, id)
|
||||
return lm.svc.ViewRule(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (res re.Rule, err error) {
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/re"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
|
||||
"github.com/go-kit/kit/metrics"
|
||||
)
|
||||
|
||||
type metricsMiddleware struct {
|
||||
counter metrics.Counter
|
||||
latency metrics.Histogram
|
||||
service re.Service
|
||||
rolemw.RoleManagerMetricsMiddleware
|
||||
}
|
||||
|
||||
var _ re.Service = (*metricsMiddleware)(nil)
|
||||
|
||||
func NewMetricsMiddleware(counter metrics.Counter, latency metrics.Histogram, service re.Service) re.Service {
|
||||
return &metricsMiddleware{
|
||||
counter: counter,
|
||||
latency: latency,
|
||||
service: service,
|
||||
RoleManagerMetricsMiddleware: rolemw.NewMetrics("re", service, counter, latency),
|
||||
}
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "add_rule").Add(1)
|
||||
mm.latency.With("method", "add_rule").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.AddRule(ctx, session, r)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "view_rule").Add(1)
|
||||
mm.latency.With("method", "view_rule").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.ViewRule(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "update_rule").Add(1)
|
||||
mm.latency.With("method", "update_rule").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.UpdateRule(ctx, session, r)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "update_rule_tags").Add(1)
|
||||
mm.latency.With("method", "update_rule_tags").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.UpdateRuleTags(ctx, session, r)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "update_rule_schedule").Add(1)
|
||||
mm.latency.With("method", "update_rule_schedule").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.UpdateRuleSchedule(ctx, session, r)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "list_rules").Add(1)
|
||||
mm.latency.With("method", "list_rules").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.ListRules(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "remove_rule").Add(1)
|
||||
mm.latency.With("method", "remove_rule").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.RemoveRule(ctx, session, id)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "enable_rule").Add(1)
|
||||
mm.latency.With("method", "enable_rule").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.EnableRule(ctx, session, id)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "disable_rule").Add(1)
|
||||
mm.latency.With("method", "disable_rule").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.DisableRule(ctx, session, id)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) Handle(msg *messaging.Message) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "handle").Add(1)
|
||||
mm.latency.With("method", "handle").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.Handle(msg)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) StartScheduler(ctx context.Context) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "start_scheduler").Add(1)
|
||||
mm.latency.With("method", "start_scheduler").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.service.StartScheduler(ctx)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) Cancel() error {
|
||||
return mm.service.Cancel()
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/re"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
|
||||
smqTracing "github.com/absmach/supermq/pkg/tracing"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
type tracingMiddleware struct {
|
||||
tracer trace.Tracer
|
||||
svc re.Service
|
||||
rolemw.RoleManagerTracing
|
||||
}
|
||||
|
||||
var _ re.Service = (*tracingMiddleware)(nil)
|
||||
|
||||
func NewTracingMiddleware(tracer trace.Tracer, svc re.Service) re.Service {
|
||||
return &tracingMiddleware{
|
||||
tracer: tracer,
|
||||
svc: svc,
|
||||
RoleManagerTracing: rolemw.NewTracing("re", svc, tracer),
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "add_rule", trace.WithAttributes(
|
||||
attribute.String("name", r.Name),
|
||||
attribute.String("domain_id", r.DomainID),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.AddRule(ctx, session, r)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "view_rule", trace.WithAttributes(
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.ViewRule(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_rule", trace.WithAttributes(
|
||||
attribute.String("id", r.ID),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.UpdateRule(ctx, session, r)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_rule_tags", trace.WithAttributes(
|
||||
attribute.String("id", r.ID),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.UpdateRuleTags(ctx, session, r)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_rule_schedule", trace.WithAttributes(
|
||||
attribute.String("id", r.ID),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.UpdateRuleSchedule(ctx, session, r)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "list_rules", trace.WithAttributes(
|
||||
attribute.Int("offset", int(pm.Offset)),
|
||||
attribute.Int("limit", int(pm.Limit)),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.ListRules(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "remove_rule", trace.WithAttributes(
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.RemoveRule(ctx, session, id)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "enable_rule", trace.WithAttributes(
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.EnableRule(ctx, session, id)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "disable_rule", trace.WithAttributes(
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.DisableRule(ctx, session, id)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Handle(msg *messaging.Message) error {
|
||||
_, span := smqTracing.StartSpan(context.Background(), tm.tracer, "handle", trace.WithAttributes(
|
||||
attribute.String("channel", msg.Channel),
|
||||
attribute.String("subtopic", msg.Subtopic),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Handle(msg)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) StartScheduler(ctx context.Context) error {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "start_scheduler")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.StartScheduler(ctx)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Cancel() error {
|
||||
return tm.svc.Cancel()
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
+1500
-12
File diff suppressed because it is too large
Load Diff
@@ -1,68 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package re
|
||||
|
||||
import (
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/absmach/supermq/pkg/permissions"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
)
|
||||
|
||||
const EntityType = "rules"
|
||||
|
||||
const (
|
||||
OpAddRule permissions.Operation = iota
|
||||
OpViewRule
|
||||
OpUpdateRule
|
||||
OpUpdateRuleTags
|
||||
OpUpdateRuleSchedule
|
||||
OpListRules
|
||||
OpRemoveRule
|
||||
OpEnableRule
|
||||
OpDisableRule
|
||||
)
|
||||
|
||||
const (
|
||||
OpAddRuleStr = "OpAddRule"
|
||||
OpViewRuleStr = "OpViewRule"
|
||||
OpUpdateRuleStr = "OpUpdateRule"
|
||||
OpUpdateRuleTagsStr = "OpUpdateRuleTags"
|
||||
OpUpdateRuleScheduleStr = "OpUpdateRuleSchedule"
|
||||
OpListRulesStr = "OpListRules"
|
||||
OpRemoveRuleStr = "OpRemoveRule"
|
||||
OpEnableRuleStr = "OpEnableRule"
|
||||
OpDisableRuleStr = "OpDisableRule"
|
||||
)
|
||||
|
||||
func GetPermission(op permissions.Operation) (string, error) {
|
||||
if op < OpAddRule || op > OpDisableRule {
|
||||
return "", errors.New("invalid operation")
|
||||
}
|
||||
return policies.MembershipPermission, nil
|
||||
}
|
||||
|
||||
func OperationName(op permissions.Operation) string {
|
||||
switch op {
|
||||
case OpAddRule:
|
||||
return OpAddRuleStr
|
||||
case OpViewRule:
|
||||
return OpViewRuleStr
|
||||
case OpUpdateRule:
|
||||
return OpUpdateRuleStr
|
||||
case OpUpdateRuleTags:
|
||||
return OpUpdateRuleTagsStr
|
||||
case OpUpdateRuleSchedule:
|
||||
return OpUpdateRuleScheduleStr
|
||||
case OpListRules:
|
||||
return OpListRulesStr
|
||||
case OpRemoveRule:
|
||||
return OpRemoveRuleStr
|
||||
case OpEnableRule:
|
||||
return OpEnableRuleStr
|
||||
case OpDisableRule:
|
||||
return OpDisableRuleStr
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package operations
|
||||
|
||||
import (
|
||||
"github.com/absmach/supermq/pkg/permissions"
|
||||
)
|
||||
|
||||
const EntityType = "rule"
|
||||
|
||||
// Rule Operations.
|
||||
const (
|
||||
OpAddRule permissions.Operation = iota
|
||||
OpViewRule
|
||||
OpUpdateRule
|
||||
OpUpdateRuleTags
|
||||
OpUpdateRuleSchedule
|
||||
OpRemoveRule
|
||||
OpListRules
|
||||
OpEnableRule
|
||||
OpDisableRule
|
||||
)
|
||||
|
||||
func OperationDetails() map[permissions.Operation]permissions.OperationDetails {
|
||||
return map[permissions.Operation]permissions.OperationDetails{
|
||||
OpAddRule: {
|
||||
Name: "add",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpViewRule: {
|
||||
Name: "view",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpUpdateRule: {
|
||||
Name: "update",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpUpdateRuleTags: {
|
||||
Name: "update_tags",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpUpdateRuleSchedule: {
|
||||
Name: "update_schedule",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpRemoveRule: {
|
||||
Name: "delete",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpListRules: {
|
||||
Name: "list",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpEnableRule: {
|
||||
Name: "enable",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpDisableRule: {
|
||||
Name: "disable",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
+20
-2
@@ -4,12 +4,20 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
dpostgres "github.com/absmach/supermq/domains/postgres"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // required for SQL access
|
||||
migrate "github.com/rubenv/sql-migrate"
|
||||
)
|
||||
|
||||
func Migration() *migrate.MemoryMigrationSource {
|
||||
return &migrate.MemoryMigrationSource{
|
||||
func Migration() (*migrate.MemoryMigrationSource, error) {
|
||||
rolesMigration, err := rolesPostgres.Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName)
|
||||
if err != nil {
|
||||
return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err)
|
||||
}
|
||||
rulesMigration := &migrate.MemoryMigrationSource{
|
||||
Migrations: []*migrate.Migration{
|
||||
{
|
||||
Id: "rules_01",
|
||||
@@ -65,4 +73,14 @@ func Migration() *migrate.MemoryMigrationSource {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rulesMigration.Migrations = append(rulesMigration.Migrations, rolesMigration.Migrations...)
|
||||
|
||||
domainsMigration, err := dpostgres.Migration()
|
||||
if err != nil {
|
||||
return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err)
|
||||
}
|
||||
rulesMigration.Migrations = append(rulesMigration.Migrations, domainsMigration.Migrations...)
|
||||
|
||||
return rulesMigration, nil
|
||||
}
|
||||
|
||||
+165
-1
@@ -10,19 +10,32 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mgPolicies "github.com/absmach/magistrala/pkg/policies"
|
||||
"github.com/absmach/magistrala/re"
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
"github.com/absmach/supermq/pkg/postgres"
|
||||
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
|
||||
)
|
||||
|
||||
const (
|
||||
rolesTableNamePrefix = "rules"
|
||||
entityTableName = "rules"
|
||||
entityIDColumnName = "id"
|
||||
)
|
||||
|
||||
type PostgresRepository struct {
|
||||
DB postgres.Database
|
||||
rolesPostgres.Repository
|
||||
}
|
||||
|
||||
func NewRepository(db postgres.Database) re.Repository {
|
||||
return &PostgresRepository{DB: db}
|
||||
rolesRepo := rolesPostgres.NewRepository(db, mgPolicies.RulesType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
|
||||
return &PostgresRepository{
|
||||
DB: db,
|
||||
Repository: rolesRepo,
|
||||
}
|
||||
}
|
||||
|
||||
func (repo *PostgresRepository) AddRule(ctx context.Context, r re.Rule) (re.Rule, error) {
|
||||
@@ -82,6 +95,157 @@ func (repo *PostgresRepository) ViewRule(ctx context.Context, id string) (re.Rul
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (repo *PostgresRepository) RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (re.Rule, error) {
|
||||
query := `
|
||||
WITH selected_rule AS (
|
||||
SELECT
|
||||
r.id,
|
||||
r.domain_id
|
||||
FROM
|
||||
rules r
|
||||
WHERE
|
||||
r.id = :id
|
||||
LIMIT 1
|
||||
),
|
||||
selected_rule_roles AS (
|
||||
SELECT
|
||||
rr.entity_id AS rule_id,
|
||||
rrm.member_id AS member_id,
|
||||
rr.id AS role_id,
|
||||
rr."name" AS role_name,
|
||||
jsonb_agg(DISTINCT rra."action") AS actions,
|
||||
'direct' AS access_type,
|
||||
'' AS access_provider_id
|
||||
FROM
|
||||
rules_roles rr
|
||||
JOIN
|
||||
rules_role_members rrm ON rr.id = rrm.role_id
|
||||
JOIN
|
||||
rules_role_actions rra ON rr.id = rra.role_id
|
||||
JOIN
|
||||
selected_rule sr ON sr.id = rr.entity_id
|
||||
AND rrm.member_id = :member_id
|
||||
GROUP BY
|
||||
rr.entity_id, rr.id, rr.name, rrm.member_id
|
||||
),
|
||||
selected_domain_roles AS (
|
||||
SELECT
|
||||
sr.id AS rule_id,
|
||||
drm.member_id AS member_id,
|
||||
dr.id AS role_id,
|
||||
dr."name" AS role_name,
|
||||
jsonb_agg(DISTINCT all_actions."action") AS actions,
|
||||
'domain' AS access_type,
|
||||
dr.entity_id AS access_provider_id
|
||||
FROM
|
||||
domains d
|
||||
JOIN
|
||||
selected_rule sr ON sr.domain_id = d.id
|
||||
JOIN
|
||||
domains_roles dr ON dr.entity_id = d.id
|
||||
JOIN
|
||||
domains_role_members drm ON dr.id = drm.role_id
|
||||
JOIN
|
||||
domains_role_actions dra ON dr.id = dra.role_id
|
||||
JOIN
|
||||
domains_role_actions all_actions ON dr.id = all_actions.role_id
|
||||
WHERE
|
||||
drm.member_id = :member_id
|
||||
AND dra."action" LIKE 'rule%'
|
||||
GROUP BY
|
||||
sr.id, dr.entity_id, dr.id, dr."name", drm.member_id
|
||||
),
|
||||
all_roles AS (
|
||||
SELECT
|
||||
srr.rule_id,
|
||||
srr.member_id,
|
||||
srr.role_id,
|
||||
srr.role_name,
|
||||
srr.actions,
|
||||
srr.access_type,
|
||||
srr.access_provider_id
|
||||
FROM
|
||||
selected_rule_roles srr
|
||||
UNION
|
||||
SELECT
|
||||
sdr.rule_id,
|
||||
sdr.member_id,
|
||||
sdr.role_id,
|
||||
sdr.role_name,
|
||||
sdr.actions,
|
||||
sdr.access_type,
|
||||
sdr.access_provider_id
|
||||
FROM
|
||||
selected_domain_roles sdr
|
||||
),
|
||||
final_roles AS (
|
||||
SELECT
|
||||
ar.rule_id,
|
||||
ar.member_id,
|
||||
jsonb_agg(
|
||||
jsonb_build_object(
|
||||
'role_id', ar.role_id,
|
||||
'role_name', ar.role_name,
|
||||
'actions', ar.actions,
|
||||
'access_type', ar.access_type,
|
||||
'access_provider_id', ar.access_provider_id
|
||||
)
|
||||
) AS roles
|
||||
FROM all_roles ar
|
||||
GROUP BY
|
||||
ar.rule_id, ar.member_id
|
||||
)
|
||||
SELECT
|
||||
r2.id,
|
||||
r2."name",
|
||||
r2.domain_id,
|
||||
r2.tags,
|
||||
r2.metadata,
|
||||
r2.input_channel,
|
||||
r2.input_topic,
|
||||
r2.outputs,
|
||||
r2.status,
|
||||
r2.logic_type,
|
||||
r2.logic_value,
|
||||
r2.time,
|
||||
r2.recurring,
|
||||
r2.recurring_period,
|
||||
r2.start_datetime,
|
||||
r2.created_at,
|
||||
r2.created_by,
|
||||
r2.updated_at,
|
||||
r2.updated_by,
|
||||
fr.member_id,
|
||||
fr.roles
|
||||
FROM rules r2
|
||||
JOIN final_roles fr ON fr.rule_id = r2.id
|
||||
`
|
||||
parameters := map[string]any{
|
||||
"id": id,
|
||||
"member_id": memberID,
|
||||
}
|
||||
row, err := repo.DB.NamedQueryContext(ctx, query, parameters)
|
||||
if err != nil {
|
||||
return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbrule := dbRule{}
|
||||
if !row.Next() {
|
||||
return re.Rule{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
if err := row.StructScan(&dbrule); err != nil {
|
||||
return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
r, err := dbToRule(dbrule)
|
||||
if err != nil {
|
||||
return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (repo *PostgresRepository) UpdateRuleStatus(ctx context.Context, r re.Rule) (re.Rule, error) {
|
||||
q := `UPDATE rules
|
||||
SET status = :status, updated_at = :updated_at, updated_by = :updated_by
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/absmach/magistrala/pkg/schedule"
|
||||
"github.com/absmach/magistrala/re"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
"github.com/jackc/pgtype"
|
||||
)
|
||||
|
||||
@@ -35,6 +36,8 @@ type dbRule struct {
|
||||
CreatedBy string `db:"created_by"`
|
||||
UpdatedAt time.Time `db:"updated_at"`
|
||||
UpdatedBy string `db:"updated_by"`
|
||||
MemberID string `db:"member_id,omitempty"`
|
||||
Roles json.RawMessage `db:"roles,omitempty"`
|
||||
}
|
||||
|
||||
func ruleToDb(r re.Rule) (dbRule, error) {
|
||||
@@ -108,6 +111,13 @@ func dbToRule(dto dbRule) (re.Rule, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var roles []roles.MemberRoleActions
|
||||
if dto.Roles != nil {
|
||||
if err := json.Unmarshal(dto.Roles, &roles); err != nil {
|
||||
return re.Rule{}, errors.Wrap(errors.ErrMalformedEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
return re.Rule{
|
||||
ID: dto.ID,
|
||||
Name: dto.Name,
|
||||
@@ -132,6 +142,7 @@ func dbToRule(dto dbRule) (re.Rule, error) {
|
||||
CreatedBy: dto.CreatedBy,
|
||||
UpdatedAt: dto.UpdatedAt,
|
||||
UpdatedBy: dto.UpdatedBy,
|
||||
Roles: roles,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -75,7 +75,11 @@ func TestMain(m *testing.M) {
|
||||
SSLRootCert: "",
|
||||
}
|
||||
|
||||
if db, err = postgres.Setup(dbConfig, *repostgres.Migration()); err != nil {
|
||||
migration, err := repostgres.Migration()
|
||||
if err != nil {
|
||||
log.Fatalf("Could not get migration: %s", err)
|
||||
}
|
||||
if db, err = postgres.Setup(dbConfig, *migration); err != nil {
|
||||
log.Fatalf("Could not setup test DB connection: %s", err)
|
||||
}
|
||||
|
||||
|
||||
+21
-16
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -42,21 +43,22 @@ var outputRegistry = map[outputs.OutputType]func() Runnable{
|
||||
}
|
||||
|
||||
type Rule struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DomainID string `json:"domain"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
InputChannel string `json:"input_channel"`
|
||||
InputTopic string `json:"input_topic"`
|
||||
Logic Script `json:"logic"`
|
||||
Outputs Outputs `json:"outputs,omitempty"`
|
||||
Schedule schedule.Schedule `json:"schedule,omitempty"`
|
||||
Status Status `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
UpdatedBy string `json:"updated_by"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DomainID string `json:"domain"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
InputChannel string `json:"input_channel"`
|
||||
InputTopic string `json:"input_topic"`
|
||||
Logic Script `json:"logic"`
|
||||
Outputs Outputs `json:"outputs,omitempty"`
|
||||
Schedule schedule.Schedule `json:"schedule,omitempty"`
|
||||
Status Status `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
UpdatedBy string `json:"updated_by"`
|
||||
Roles []roles.MemberRoleActions `json:"roles,omitempty"`
|
||||
}
|
||||
|
||||
// EventEncode converts a Rule struct to map[string]any at event producer.
|
||||
@@ -224,7 +226,7 @@ type Page struct {
|
||||
type Service interface {
|
||||
messaging.MessageHandler
|
||||
AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error)
|
||||
ViewRule(ctx context.Context, session authn.Session, id string) (Rule, error)
|
||||
ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (Rule, error)
|
||||
UpdateRule(ctx context.Context, session authn.Session, r Rule) (Rule, error)
|
||||
UpdateRuleTags(ctx context.Context, session authn.Session, r Rule) (Rule, error)
|
||||
UpdateRuleSchedule(ctx context.Context, session authn.Session, r Rule) (Rule, error)
|
||||
@@ -234,11 +236,13 @@ type Service interface {
|
||||
DisableRule(ctx context.Context, session authn.Session, id string) (Rule, error)
|
||||
|
||||
StartScheduler(ctx context.Context) error
|
||||
roles.RoleManager
|
||||
}
|
||||
|
||||
type Repository interface {
|
||||
AddRule(ctx context.Context, r Rule) (Rule, error)
|
||||
ViewRule(ctx context.Context, id string) (Rule, error)
|
||||
RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (Rule, error)
|
||||
UpdateRule(ctx context.Context, r Rule) (Rule, error)
|
||||
UpdateRuleTags(ctx context.Context, r Rule) (Rule, error)
|
||||
UpdateRuleSchedule(ctx context.Context, r Rule) (Rule, error)
|
||||
@@ -246,4 +250,5 @@ type Repository interface {
|
||||
UpdateRuleStatus(ctx context.Context, r Rule) (Rule, error)
|
||||
ListRules(ctx context.Context, pm PageMeta) (Page, error)
|
||||
UpdateRuleDue(ctx context.Context, id string, due time.Time) (Rule, error)
|
||||
roles.Repository
|
||||
}
|
||||
|
||||
+57
-14
@@ -11,11 +11,14 @@ import (
|
||||
"github.com/absmach/magistrala/pkg/emailer"
|
||||
pkglog "github.com/absmach/magistrala/pkg/logger"
|
||||
"github.com/absmach/magistrala/pkg/ticker"
|
||||
"github.com/absmach/magistrala/re/operations"
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
)
|
||||
|
||||
var ErrGoroutinesNotAllowed = errors.New("goroutines are not allowed in Go scripts")
|
||||
@@ -30,23 +33,29 @@ type re struct {
|
||||
ticker ticker.Ticker
|
||||
email emailer.Emailer
|
||||
readers grpcReadersV1.ReadersServiceClient
|
||||
roles.ProvisionManageService
|
||||
}
|
||||
|
||||
func NewService(repo Repository, runInfo chan pkglog.RunInfo, idp supermq.IDProvider, rePubSub messaging.PubSub, writersPub, alarmsPub messaging.Publisher, tck ticker.Ticker, emailer emailer.Emailer, readers grpcReadersV1.ReadersServiceClient) Service {
|
||||
return &re{
|
||||
repo: repo,
|
||||
idp: idp,
|
||||
runInfo: runInfo,
|
||||
rePubSub: rePubSub,
|
||||
writersPub: writersPub,
|
||||
alarmsPub: alarmsPub,
|
||||
ticker: tck,
|
||||
email: emailer,
|
||||
readers: readers,
|
||||
func NewService(repo Repository, runInfo chan pkglog.RunInfo, policy policies.Service, idp supermq.IDProvider, rePubSub messaging.PubSub, writersPub, alarmsPub messaging.Publisher, tck ticker.Ticker, emailer emailer.Emailer, readers grpcReadersV1.ReadersServiceClient, availableActions []roles.Action, builtInRoles map[roles.BuiltInRoleName][]roles.Action) (Service, error) {
|
||||
rpms, err := roles.NewProvisionManageService(operations.EntityType, repo, policy, idp, availableActions, builtInRoles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &re{
|
||||
repo: repo,
|
||||
idp: idp,
|
||||
runInfo: runInfo,
|
||||
rePubSub: rePubSub,
|
||||
writersPub: writersPub,
|
||||
alarmsPub: alarmsPub,
|
||||
ticker: tck,
|
||||
email: emailer,
|
||||
readers: readers,
|
||||
ProvisionManageService: rpms,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) {
|
||||
func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (retRule Rule, retErr error) {
|
||||
if r.Logic.Type == GoType && goKeywordRegex.MatchString(r.Logic.Value) {
|
||||
return Rule{}, errors.Wrap(svcerr.ErrMalformedEntity, ErrGoroutinesNotAllowed)
|
||||
}
|
||||
@@ -72,11 +81,45 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule,
|
||||
return Rule{}, errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
if errRollBack := re.repo.RemoveRule(ctx, rule.ID); errRollBack != nil {
|
||||
retErr = errors.Wrap(retErr, errors.Wrap(svcerr.ErrRollbackRepo, errRollBack))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
newBuiltInRoleMembers := map[roles.BuiltInRoleName][]roles.Member{
|
||||
BuiltInRoleAdmin: {roles.Member(session.UserID)},
|
||||
}
|
||||
|
||||
optionalPolicies := []policies.Policy{
|
||||
{
|
||||
SubjectType: policies.DomainType,
|
||||
Subject: session.DomainID,
|
||||
Relation: policies.DomainRelation,
|
||||
ObjectType: operations.EntityType,
|
||||
Object: rule.ID,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = re.AddNewEntitiesRoles(ctx, session.DomainID, session.UserID, []string{rule.ID}, optionalPolicies, newBuiltInRoleMembers)
|
||||
if err != nil {
|
||||
return Rule{}, errors.Wrap(svcerr.ErrAddPolicies, err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (re *re) ViewRule(ctx context.Context, session authn.Session, id string) (Rule, error) {
|
||||
rule, err := re.repo.ViewRule(ctx, id)
|
||||
func (re *re) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (Rule, error) {
|
||||
var rule Rule
|
||||
var err error
|
||||
switch withRoles {
|
||||
case true:
|
||||
rule, err = re.repo.RetrieveByIDWithRoles(ctx, id, session.UserID)
|
||||
default:
|
||||
rule, err = re.repo.ViewRule(ctx, id)
|
||||
}
|
||||
if err != nil {
|
||||
return Rule{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
+197
-32
@@ -26,6 +26,8 @@ import (
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
pubsubmocks "github.com/absmach/supermq/pkg/messaging/mocks"
|
||||
policymocks "github.com/absmach/supermq/pkg/policies/mocks"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
"github.com/absmach/supermq/pkg/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
@@ -59,27 +61,40 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *tmocks.Ticker, *emocks.Emailer) {
|
||||
func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *tmocks.Ticker, *emocks.Emailer, *policymocks.Service) {
|
||||
repo := new(mocks.Repository)
|
||||
mockTicker := new(tmocks.Ticker)
|
||||
idProvider := uuid.NewMock()
|
||||
pubsub := pubsubmocks.NewPubSub(t)
|
||||
readersSvc := new(readmocks.ReadersServiceClient)
|
||||
e := new(emocks.Emailer)
|
||||
return re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc), repo, pubsub, mockTicker, e
|
||||
policy := new(policymocks.Service)
|
||||
availableActions := []roles.Action{}
|
||||
builtInRoles := map[roles.BuiltInRoleName][]roles.Action{
|
||||
"admin": availableActions,
|
||||
}
|
||||
svc, err := re.NewService(repo, runInfo, policy, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc, availableActions, builtInRoles)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service: %v", err)
|
||||
}
|
||||
return svc, repo, pubsub, mockTicker, e, policy
|
||||
}
|
||||
|
||||
func TestAddRule(t *testing.T) {
|
||||
// nolint:dogsled
|
||||
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, _, _, _, policies := newService(t, make(chan pkglog.RunInfo))
|
||||
ruleName := namegen.Generate()
|
||||
now := time.Now().Add(time.Hour)
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
rule re.Rule
|
||||
res re.Rule
|
||||
err error
|
||||
desc string
|
||||
session authn.Session
|
||||
rule re.Rule
|
||||
res re.Rule
|
||||
err error
|
||||
addPoliciesErr error
|
||||
deletePolicies error
|
||||
addRoleErr error
|
||||
deleteErr error
|
||||
}{
|
||||
{
|
||||
desc: "Add rule successfully",
|
||||
@@ -109,7 +124,10 @@ func TestAddRule(t *testing.T) {
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
err: nil,
|
||||
err: nil,
|
||||
addPoliciesErr: nil,
|
||||
addRoleErr: nil,
|
||||
deleteErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "Add rule with failed repo",
|
||||
@@ -126,7 +144,11 @@ func TestAddRule(t *testing.T) {
|
||||
Time: now,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrCreateEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
addPoliciesErr: nil,
|
||||
deletePolicies: nil,
|
||||
addRoleErr: nil,
|
||||
deleteErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "Add rule with non-zero StartDateTime",
|
||||
@@ -158,7 +180,136 @@ func TestAddRule(t *testing.T) {
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
err: nil,
|
||||
err: nil,
|
||||
addPoliciesErr: nil,
|
||||
addRoleErr: nil,
|
||||
deleteErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "Add rule with failed to add roles and failed to delete policies",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
rule: re.Rule{
|
||||
Name: ruleName,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: pkgSch.Schedule{
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
},
|
||||
res: re.Rule{
|
||||
Name: ruleName,
|
||||
ID: ruleID,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: pkgSch.Schedule{
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
addRoleErr: svcerr.ErrCreateEntity,
|
||||
deletePolicies: svcerr.ErrRemoveEntity,
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
},
|
||||
{
|
||||
desc: "Add rule with failed to add policies",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
rule: re.Rule{
|
||||
Name: ruleName,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: pkgSch.Schedule{
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
},
|
||||
res: re.Rule{
|
||||
Name: ruleName,
|
||||
ID: ruleID,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: pkgSch.Schedule{
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
addPoliciesErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAddPolicies,
|
||||
},
|
||||
{
|
||||
desc: "Add rule with failed to add policies and failed rollback",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
rule: re.Rule{
|
||||
Name: ruleName,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: pkgSch.Schedule{
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
},
|
||||
res: re.Rule{
|
||||
Name: ruleName,
|
||||
ID: ruleID,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: pkgSch.Schedule{
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
addPoliciesErr: svcerr.ErrAuthorization,
|
||||
deleteErr: svcerr.ErrRemoveEntity,
|
||||
err: svcerr.ErrRollbackRepo,
|
||||
},
|
||||
{
|
||||
desc: "Add rule with failed to add roles",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
rule: re.Rule{
|
||||
Name: ruleName,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: pkgSch.Schedule{
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
},
|
||||
res: re.Rule{
|
||||
Name: ruleName,
|
||||
ID: ruleID,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: pkgSch.Schedule{
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
addRoleErr: svcerr.ErrCreateEntity,
|
||||
err: svcerr.ErrAddPolicies,
|
||||
},
|
||||
{
|
||||
desc: "Add rule with Go script containing goroutines",
|
||||
@@ -186,6 +337,10 @@ func TestAddRule(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("AddRule", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr)
|
||||
policyCall2 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePolicies).Maybe()
|
||||
repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr)
|
||||
repoCall2 := repo.On("Remove", context.Background(), mock.Anything).Return(tc.deleteErr).Maybe()
|
||||
res, err := svc.AddRule(context.Background(), tc.session, tc.rule)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
@@ -193,14 +348,18 @@ func TestAddRule(t *testing.T) {
|
||||
assert.Equal(t, tc.rule.Name, res.Name)
|
||||
assert.Equal(t, tc.rule.Schedule, res.Schedule)
|
||||
}
|
||||
defer repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
policyCall2.Unset()
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewRule(t *testing.T) {
|
||||
// nolint:dogsled
|
||||
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
|
||||
now := time.Now().Add(time.Hour)
|
||||
cases := []struct {
|
||||
@@ -246,7 +405,7 @@ func TestViewRule(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("ViewRule", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
res, err := svc.ViewRule(context.Background(), tc.session, tc.id)
|
||||
res, err := svc.ViewRule(context.Background(), tc.session, tc.id, false)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
@@ -259,7 +418,7 @@ func TestViewRule(t *testing.T) {
|
||||
|
||||
func TestUpdateRule(t *testing.T) {
|
||||
// nolint:dogsled
|
||||
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
|
||||
newName := namegen.Generate()
|
||||
now := time.Now().Add(time.Hour)
|
||||
@@ -370,7 +529,7 @@ func TestUpdateRule(t *testing.T) {
|
||||
|
||||
func TestUpdateRuleTags(t *testing.T) {
|
||||
// nolint:dogsled
|
||||
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -427,7 +586,7 @@ func TestUpdateRuleTags(t *testing.T) {
|
||||
|
||||
func TestUpdateRuleSchedule(t *testing.T) {
|
||||
// nolint:dogsled
|
||||
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
|
||||
now := time.Now().UTC()
|
||||
future := now.Add(2 * time.Hour)
|
||||
@@ -495,7 +654,7 @@ func TestUpdateRuleSchedule(t *testing.T) {
|
||||
|
||||
func TestListRules(t *testing.T) {
|
||||
// nolint:dogsled
|
||||
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
numRules := 50
|
||||
now := time.Now().Add(time.Hour)
|
||||
var rules []re.Rule
|
||||
@@ -629,13 +788,14 @@ func TestListRules(t *testing.T) {
|
||||
|
||||
func TestRemoveRule(t *testing.T) {
|
||||
// nolint:dogsled
|
||||
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, _, _, _, policies := newService(t, make(chan pkglog.RunInfo))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
id string
|
||||
err error
|
||||
desc string
|
||||
session authn.Session
|
||||
id string
|
||||
err error
|
||||
deletePoliciesErr error
|
||||
}{
|
||||
{
|
||||
desc: "remove rule successfully",
|
||||
@@ -643,8 +803,9 @@ func TestRemoveRule(t *testing.T) {
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
err: nil,
|
||||
id: ruleID,
|
||||
err: nil,
|
||||
deletePoliciesErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "remove rule with failed repo",
|
||||
@@ -652,25 +813,28 @@ func TestRemoveRule(t *testing.T) {
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
id: ruleID,
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
deletePoliciesErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("RemoveRule", mock.Anything, mock.Anything).Return(tc.err)
|
||||
policyCall := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr)
|
||||
err := svc.RemoveRule(context.Background(), tc.session, tc.id)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
defer repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableRule(t *testing.T) {
|
||||
// nolint:dogsled
|
||||
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
|
||||
now := time.Now()
|
||||
|
||||
@@ -730,7 +894,7 @@ func TestEnableRule(t *testing.T) {
|
||||
|
||||
func TestDisableRule(t *testing.T) {
|
||||
// nolint:dogsled
|
||||
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
|
||||
now := time.Now()
|
||||
|
||||
@@ -789,7 +953,7 @@ func TestDisableRule(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandle(t *testing.T) {
|
||||
svc, repo, pubmocks, _, emailer := newService(t, make(chan pkglog.RunInfo))
|
||||
svc, repo, pubmocks, _, emailer, _ := newService(t, make(chan pkglog.RunInfo))
|
||||
now := time.Now()
|
||||
scheduled := false
|
||||
|
||||
@@ -1461,7 +1625,8 @@ func TestHandle(t *testing.T) {
|
||||
func TestStartScheduler(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Minute)
|
||||
ri := make(chan pkglog.RunInfo)
|
||||
svc, repo, _, ticker, _ := newService(t, ri)
|
||||
// nolint:dogsled
|
||||
svc, repo, _, ticker, _, _ := newService(t, ri)
|
||||
|
||||
ctxCases := []struct {
|
||||
desc string
|
||||
|
||||
Reference in New Issue
Block a user