MG-370 - Add fine grained access control to rules engine (#402)

* update go mod file

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix rules endpoint tests

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix yaml file

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix build

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* address comments

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* remove roles from alarms

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* change approach for schema combaine

Signed-off-by: Arvindh <arvindh91@gmail.com>

* change approach for schema combaine

Signed-off-by: Arvindh <arvindh91@gmail.com>

* fix permissions for rules

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix authorization file

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix linter

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix linter

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

---------

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
Signed-off-by: Arvindh <arvindh91@gmail.com>
Co-authored-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
Steve Munene
2026-03-05 13:42:51 +03:00
committed by GitHub
parent 8e75edc9f5
commit 362a4fc76d
35 changed files with 5431 additions and 252 deletions
+1 -1
View File
@@ -44,7 +44,7 @@ func viewRuleEndpoint(s re.Service) endpoint.Endpoint {
if err := req.validate(); err != nil {
return viewRuleRes{}, err
}
rule, err := s.ViewRule(ctx, session, req.id)
rule, err := s.ViewRule(ctx, session, req.id, req.withRoles)
if err != nil {
return viewRuleRes{}, err
}
+1 -1
View File
@@ -334,7 +334,7 @@ func TestViewRuleEndpoint(t *testing.T) {
}
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("ViewRule", mock.Anything, tc.authnRes, tc.id).Return(tc.svcRes, tc.svcErr)
svcCall := svc.On("ViewRule", mock.Anything, tc.authnRes, tc.id, false).Return(tc.svcRes, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
+2 -1
View File
@@ -33,7 +33,8 @@ func (req addRuleReq) validate() error {
}
type viewRuleReq struct {
id string
id string
withRoles bool
}
func (req viewRuleReq) validate() error {
+12 -1
View File
@@ -16,6 +16,7 @@ import (
apiutil "github.com/absmach/supermq/api/http/util"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -36,6 +37,8 @@ func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, l
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
r.Route("/{domainID}", func(r chi.Router) {
r.Route("/rules", func(r chi.Router) {
d := roleManagerHttp.NewDecoder("ruleID")
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
addRuleEndpoint(svc),
decodeAddRuleRequest,
@@ -50,6 +53,8 @@ func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, l
opts...,
), "list_rules").ServeHTTP)
r = roleManagerHttp.EntityAvailableActionsRouter(svc, d, r, opts)
r.Route("/{ruleID}", func(r chi.Router) {
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
viewRuleEndpoint(svc),
@@ -99,6 +104,8 @@ func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, l
api.EncodeResponse,
opts...,
), "disable_rule").ServeHTTP)
roleManagerHttp.EntityRoleMangerRouter(svc, d, r, opts)
})
})
})
@@ -123,7 +130,11 @@ func decodeAddRuleRequest(_ context.Context, r *http.Request) (any, error) {
func decodeViewRuleRequest(_ context.Context, r *http.Request) (any, error) {
id := chi.URLParam(r, ruleIdKey)
return viewRuleReq{id: id}, nil
withRoles, err := apiutil.ReadBoolQuery(r, api.RolesKey, false)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
return viewRuleReq{id: id, withRoles: withRoles}, nil
}
func decodeUpdateRuleRequest(_ context.Context, r *http.Request) (any, error) {
+8
View File
@@ -0,0 +1,8 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package re
import "github.com/absmach/supermq/pkg/roles"
const BuiltInRoleAdmin roles.BuiltInRoleName = "admin"
+9 -4
View File
@@ -11,6 +11,7 @@ import (
"github.com/absmach/supermq/pkg/events"
"github.com/absmach/supermq/pkg/events/store"
"github.com/absmach/supermq/pkg/messaging"
rmEvents "github.com/absmach/supermq/pkg/roles/rolemanager/events"
"github.com/go-chi/chi/v5/middleware"
)
@@ -32,6 +33,7 @@ var _ re.Service = (*eventStore)(nil)
type eventStore struct {
events.Publisher
svc re.Service
rmEvents.RoleManagerEventStore
}
// NewEventStoreMiddleware returns wrapper around rules service that sends
@@ -42,9 +44,12 @@ func NewEventStoreMiddleware(ctx context.Context, svc re.Service, url string) (r
return nil, err
}
res := rmEvents.NewRoleManagerEventStore("alarms", supermqPrefix, svc, publisher)
return &eventStore{
svc: svc,
Publisher: publisher,
svc: svc,
Publisher: publisher,
RoleManagerEventStore: res,
}, nil
}
@@ -78,8 +83,8 @@ func (es *eventStore) ListRules(ctx context.Context, session authn.Session, pm r
return page, nil
}
func (es *eventStore) ViewRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
rule, err := es.svc.ViewRule(ctx, session, id)
func (es *eventStore) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
rule, err := es.svc.ViewRule(ctx, session, id, withRoles)
if err != nil {
return rule, err
}
+36 -23
View File
@@ -7,12 +7,14 @@ import (
"context"
"github.com/absmach/magistrala/re"
"github.com/absmach/magistrala/re/operations"
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/permissions"
"github.com/absmach/supermq/pkg/policies"
rolemgr "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
)
var (
@@ -23,36 +25,47 @@ var (
)
type authorizationMiddleware struct {
svc re.Service
authz smqauthz.Authorization
svc re.Service
authz smqauthz.Authorization
entitiesOps permissions.EntitiesOperations[permissions.Operation]
rolemgr.RoleManagerAuthorizationMiddleware
}
// AuthorizationMiddleware adds authorization to the re service.
func AuthorizationMiddleware(svc re.Service, authz smqauthz.Authorization) (re.Service, error) {
func AuthorizationMiddleware(svc re.Service, authz smqauthz.Authorization, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation]) (re.Service, error) {
if err := entitiesOps.Validate(); err != nil {
return nil, err
}
ram, err := rolemgr.NewAuthorization(operations.EntityType, svc, authz, roleOps)
if err != nil {
return nil, err
}
return &authorizationMiddleware{
svc: svc,
authz: authz,
svc: svc,
authz: authz,
entitiesOps: entitiesOps,
RoleManagerAuthorizationMiddleware: ram,
}, nil
}
func (am *authorizationMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
if err := am.authorize(ctx, re.OpAddRule, session); err != nil {
if err := am.authorize(ctx, operations.OpAddRule, session, policies.DomainType, session.DomainID); err != nil {
return re.Rule{}, errors.Wrap(errDomainCreateRules, err)
}
return am.svc.AddRule(ctx, session, r)
}
func (am *authorizationMiddleware) ViewRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
if err := am.authorize(ctx, re.OpViewRule, session); err != nil {
func (am *authorizationMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
if err := am.authorize(ctx, operations.OpViewRule, session, operations.EntityType, id); err != nil {
return re.Rule{}, errors.Wrap(errDomainViewRules, err)
}
return am.svc.ViewRule(ctx, session, id)
return am.svc.ViewRule(ctx, session, id, withRoles)
}
func (am *authorizationMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
if err := am.authorize(ctx, re.OpUpdateRule, session); err != nil {
if err := am.authorize(ctx, operations.OpUpdateRule, session, operations.EntityType, r.ID); err != nil {
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
}
@@ -60,7 +73,7 @@ func (am *authorizationMiddleware) UpdateRule(ctx context.Context, session authn
}
func (am *authorizationMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
if err := am.authorize(ctx, re.OpUpdateRuleTags, session); err != nil {
if err := am.authorize(ctx, operations.OpUpdateRuleTags, session, operations.EntityType, r.ID); err != nil {
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
}
@@ -68,7 +81,7 @@ func (am *authorizationMiddleware) UpdateRuleTags(ctx context.Context, session a
}
func (am *authorizationMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
if err := am.authorize(ctx, re.OpUpdateRuleSchedule, session); err != nil {
if err := am.authorize(ctx, operations.OpUpdateRuleSchedule, session, operations.EntityType, r.ID); err != nil {
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
}
@@ -76,7 +89,7 @@ func (am *authorizationMiddleware) UpdateRuleSchedule(ctx context.Context, sessi
}
func (am *authorizationMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) {
if err := am.authorize(ctx, re.OpListRules, session); err != nil {
if err := am.authorize(ctx, operations.OpListRules, session, policies.DomainType, session.DomainID); err != nil {
return re.Page{}, errors.Wrap(errDomainViewRules, err)
}
@@ -84,7 +97,7 @@ func (am *authorizationMiddleware) ListRules(ctx context.Context, session authn.
}
func (am *authorizationMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, re.OpRemoveRule, session); err != nil {
if err := am.authorize(ctx, operations.OpRemoveRule, session, operations.EntityType, id); err != nil {
return errors.Wrap(errDomainDeleteRules, err)
}
@@ -92,7 +105,7 @@ func (am *authorizationMiddleware) RemoveRule(ctx context.Context, session authn
}
func (am *authorizationMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
if err := am.authorize(ctx, re.OpEnableRule, session); err != nil {
if err := am.authorize(ctx, operations.OpEnableRule, session, operations.EntityType, id); err != nil {
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
}
@@ -100,7 +113,7 @@ func (am *authorizationMiddleware) EnableRule(ctx context.Context, session authn
}
func (am *authorizationMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
if err := am.authorize(ctx, re.OpDisableRule, session); err != nil {
if err := am.authorize(ctx, operations.OpDisableRule, session, operations.EntityType, id); err != nil {
return re.Rule{}, errors.Wrap(errDomainUpdateRules, err)
}
@@ -119,8 +132,8 @@ func (am *authorizationMiddleware) Cancel() error {
return am.svc.Cancel()
}
func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions.Operation, session authn.Session) error {
perm, err := re.GetPermission(op)
func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions.Operation, session authn.Session, objType, obj string) error {
perm, err := am.entitiesOps.GetPermission(operations.EntityType, op)
if err != nil {
return err
}
@@ -130,19 +143,19 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: session.DomainUserID,
Object: session.DomainID,
ObjectType: policies.DomainType,
Permission: perm,
Object: obj,
ObjectType: objType,
Permission: perm.String(),
}
var pat *smqauthz.PATReq
if session.PatID != "" {
opName := re.OperationName(op)
opName := am.entitiesOps.OperationName(operations.EntityType, op)
pat = &smqauthz.PATReq{
UserID: session.UserID,
PatID: session.PatID,
EntityID: session.DomainID,
EntityType: re.EntityType,
EntityType: operations.EntityType,
Operation: opName,
Domain: session.DomainID,
}
+35 -18
View File
@@ -7,26 +7,43 @@ import (
"context"
"time"
mgPolicies "github.com/absmach/magistrala/pkg/policies"
"github.com/absmach/magistrala/re"
"github.com/absmach/magistrala/re/operations"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/permissions"
"github.com/absmach/supermq/pkg/policies"
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
)
var _ re.Service = (*calloutMiddleware)(nil)
type calloutMiddleware struct {
svc re.Service
callout callout.Callout
svc re.Service
callout callout.Callout
entitiesOps permissions.EntitiesOperations[permissions.Operation]
rolemw.RoleManagerCalloutMiddleware
}
const entityType = "rule"
func NewCallout(svc re.Service, callout callout.Callout) (re.Service, error) {
func NewCallout(svc re.Service, callout callout.Callout, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation]) (re.Service, error) {
call, err := rolemw.NewCallout(mgPolicies.RulesType, svc, callout, roleOps)
if err != nil {
return nil, err
}
if err := entitiesOps.Validate(); err != nil {
return nil, err
}
return &calloutMiddleware{
svc: svc,
callout: callout,
svc: svc,
callout: callout,
entitiesOps: entitiesOps,
RoleManagerCalloutMiddleware: call,
}, nil
}
@@ -36,23 +53,23 @@ func (cm *calloutMiddleware) AddRule(ctx context.Context, session authn.Session,
"count": 1,
}
if err := cm.callOut(ctx, session, re.OpAddRuleStr, params); err != nil {
if err := cm.callOut(ctx, session, operations.OpAddRule, params); err != nil {
return re.Rule{}, err
}
return cm.svc.AddRule(ctx, session, r)
}
func (cm *calloutMiddleware) ViewRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
func (cm *calloutMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
params := map[string]any{
"entity_id": id,
}
if err := cm.callOut(ctx, session, re.OpViewRuleStr, params); err != nil {
if err := cm.callOut(ctx, session, operations.OpViewRule, params); err != nil {
return re.Rule{}, err
}
return cm.svc.ViewRule(ctx, session, id)
return cm.svc.ViewRule(ctx, session, id, withRoles)
}
func (cm *calloutMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
@@ -60,7 +77,7 @@ func (cm *calloutMiddleware) UpdateRule(ctx context.Context, session authn.Sessi
"entity_id": r.ID,
}
if err := cm.callOut(ctx, session, re.OpUpdateRuleStr, params); err != nil {
if err := cm.callOut(ctx, session, operations.OpUpdateRule, params); err != nil {
return re.Rule{}, err
}
@@ -72,7 +89,7 @@ func (cm *calloutMiddleware) UpdateRuleTags(ctx context.Context, session authn.S
"entity_id": r.ID,
}
if err := cm.callOut(ctx, session, re.OpUpdateRuleTagsStr, params); err != nil {
if err := cm.callOut(ctx, session, operations.OpUpdateRuleTags, params); err != nil {
return re.Rule{}, err
}
@@ -84,7 +101,7 @@ func (cm *calloutMiddleware) UpdateRuleSchedule(ctx context.Context, session aut
"entity_id": r.ID,
}
if err := cm.callOut(ctx, session, re.OpUpdateRuleScheduleStr, params); err != nil {
if err := cm.callOut(ctx, session, operations.OpUpdateRuleSchedule, params); err != nil {
return re.Rule{}, err
}
@@ -96,7 +113,7 @@ func (cm *calloutMiddleware) ListRules(ctx context.Context, session authn.Sessio
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, re.OpListRulesStr, params); err != nil {
if err := cm.callOut(ctx, session, operations.OpListRules, params); err != nil {
return re.Page{}, err
}
@@ -108,7 +125,7 @@ func (cm *calloutMiddleware) RemoveRule(ctx context.Context, session authn.Sessi
"entity_id": id,
}
if err := cm.callOut(ctx, session, re.OpRemoveRuleStr, params); err != nil {
if err := cm.callOut(ctx, session, operations.OpRemoveRule, params); err != nil {
return err
}
@@ -120,7 +137,7 @@ func (cm *calloutMiddleware) EnableRule(ctx context.Context, session authn.Sessi
"entity_id": id,
}
if err := cm.callOut(ctx, session, re.OpEnableRuleStr, params); err != nil {
if err := cm.callOut(ctx, session, operations.OpEnableRule, params); err != nil {
return re.Rule{}, err
}
@@ -132,7 +149,7 @@ func (cm *calloutMiddleware) DisableRule(ctx context.Context, session authn.Sess
"entity_id": id,
}
if err := cm.callOut(ctx, session, re.OpDisableRuleStr, params); err != nil {
if err := cm.callOut(ctx, session, operations.OpDisableRule, params); err != nil {
return re.Rule{}, err
}
@@ -151,7 +168,7 @@ func (cm *calloutMiddleware) Cancel() error {
return cm.svc.Cancel()
}
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op string, pld map[string]any) error {
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op permissions.Operation, pld map[string]any) error {
var entityID string
if id, ok := pld["entity_id"].(string); ok {
entityID = id
@@ -159,7 +176,7 @@ func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session,
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: op,
Operation: cm.entitiesOps.OperationName(entityType, op),
EntityType: entityType,
EntityID: entityID,
CallerID: session.UserID,
+9 -3
View File
@@ -12,6 +12,7 @@ import (
"github.com/absmach/magistrala/re"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/messaging"
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
)
var _ re.Service = (*loggingMiddleware)(nil)
@@ -19,10 +20,15 @@ var _ re.Service = (*loggingMiddleware)(nil)
type loggingMiddleware struct {
logger *slog.Logger
svc re.Service
rolemw.RoleManagerLoggingMiddleware
}
func LoggingMiddleware(svc re.Service, logger *slog.Logger) re.Service {
return &loggingMiddleware{logger, svc}
return &loggingMiddleware{
logger: logger,
svc: svc,
RoleManagerLoggingMiddleware: rolemw.NewLogging("re", svc, logger),
}
}
func (lm *loggingMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (res re.Rule, err error) {
@@ -42,7 +48,7 @@ func (lm *loggingMiddleware) AddRule(ctx context.Context, session authn.Session,
return lm.svc.AddRule(ctx, session, r)
}
func (lm *loggingMiddleware) ViewRule(ctx context.Context, session authn.Session, id string) (res re.Rule, err error) {
func (lm *loggingMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (res re.Rule, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
@@ -59,7 +65,7 @@ func (lm *loggingMiddleware) ViewRule(ctx context.Context, session authn.Session
}
lm.logger.Info("View rule completed successfully", args...)
}(time.Now())
return lm.svc.ViewRule(ctx, session, id)
return lm.svc.ViewRule(ctx, session, id, withRoles)
}
func (lm *loggingMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (res re.Rule, err error) {
+136
View File
@@ -0,0 +1,136 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"time"
"github.com/absmach/magistrala/re"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/messaging"
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
"github.com/go-kit/kit/metrics"
)
type metricsMiddleware struct {
counter metrics.Counter
latency metrics.Histogram
service re.Service
rolemw.RoleManagerMetricsMiddleware
}
var _ re.Service = (*metricsMiddleware)(nil)
func NewMetricsMiddleware(counter metrics.Counter, latency metrics.Histogram, service re.Service) re.Service {
return &metricsMiddleware{
counter: counter,
latency: latency,
service: service,
RoleManagerMetricsMiddleware: rolemw.NewMetrics("re", service, counter, latency),
}
}
func (mm *metricsMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
defer func(begin time.Time) {
mm.counter.With("method", "add_rule").Add(1)
mm.latency.With("method", "add_rule").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.AddRule(ctx, session, r)
}
func (mm *metricsMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
defer func(begin time.Time) {
mm.counter.With("method", "view_rule").Add(1)
mm.latency.With("method", "view_rule").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.ViewRule(ctx, session, id, withRoles)
}
func (mm *metricsMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
defer func(begin time.Time) {
mm.counter.With("method", "update_rule").Add(1)
mm.latency.With("method", "update_rule").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.UpdateRule(ctx, session, r)
}
func (mm *metricsMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
defer func(begin time.Time) {
mm.counter.With("method", "update_rule_tags").Add(1)
mm.latency.With("method", "update_rule_tags").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.UpdateRuleTags(ctx, session, r)
}
func (mm *metricsMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
defer func(begin time.Time) {
mm.counter.With("method", "update_rule_schedule").Add(1)
mm.latency.With("method", "update_rule_schedule").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.UpdateRuleSchedule(ctx, session, r)
}
func (mm *metricsMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) {
defer func(begin time.Time) {
mm.counter.With("method", "list_rules").Add(1)
mm.latency.With("method", "list_rules").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.ListRules(ctx, session, pm)
}
func (mm *metricsMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error {
defer func(begin time.Time) {
mm.counter.With("method", "remove_rule").Add(1)
mm.latency.With("method", "remove_rule").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.RemoveRule(ctx, session, id)
}
func (mm *metricsMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
defer func(begin time.Time) {
mm.counter.With("method", "enable_rule").Add(1)
mm.latency.With("method", "enable_rule").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.EnableRule(ctx, session, id)
}
func (mm *metricsMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
defer func(begin time.Time) {
mm.counter.With("method", "disable_rule").Add(1)
mm.latency.With("method", "disable_rule").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.DisableRule(ctx, session, id)
}
func (mm *metricsMiddleware) Handle(msg *messaging.Message) error {
defer func(begin time.Time) {
mm.counter.With("method", "handle").Add(1)
mm.latency.With("method", "handle").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.Handle(msg)
}
func (mm *metricsMiddleware) StartScheduler(ctx context.Context) error {
defer func(begin time.Time) {
mm.counter.With("method", "start_scheduler").Add(1)
mm.latency.With("method", "start_scheduler").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.service.StartScheduler(ctx)
}
func (mm *metricsMiddleware) Cancel() error {
return mm.service.Cancel()
}
+136
View File
@@ -0,0 +1,136 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"github.com/absmach/magistrala/re"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/messaging"
rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
smqTracing "github.com/absmach/supermq/pkg/tracing"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
type tracingMiddleware struct {
tracer trace.Tracer
svc re.Service
rolemw.RoleManagerTracing
}
var _ re.Service = (*tracingMiddleware)(nil)
func NewTracingMiddleware(tracer trace.Tracer, svc re.Service) re.Service {
return &tracingMiddleware{
tracer: tracer,
svc: svc,
RoleManagerTracing: rolemw.NewTracing("re", svc, tracer),
}
}
func (tm *tracingMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "add_rule", trace.WithAttributes(
attribute.String("name", r.Name),
attribute.String("domain_id", r.DomainID),
))
defer span.End()
return tm.svc.AddRule(ctx, session, r)
}
func (tm *tracingMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "view_rule", trace.WithAttributes(
attribute.String("id", id),
))
defer span.End()
return tm.svc.ViewRule(ctx, session, id, withRoles)
}
func (tm *tracingMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_rule", trace.WithAttributes(
attribute.String("id", r.ID),
))
defer span.End()
return tm.svc.UpdateRule(ctx, session, r)
}
func (tm *tracingMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_rule_tags", trace.WithAttributes(
attribute.String("id", r.ID),
))
defer span.End()
return tm.svc.UpdateRuleTags(ctx, session, r)
}
func (tm *tracingMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_rule_schedule", trace.WithAttributes(
attribute.String("id", r.ID),
))
defer span.End()
return tm.svc.UpdateRuleSchedule(ctx, session, r)
}
func (tm *tracingMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "list_rules", trace.WithAttributes(
attribute.Int("offset", int(pm.Offset)),
attribute.Int("limit", int(pm.Limit)),
))
defer span.End()
return tm.svc.ListRules(ctx, session, pm)
}
func (tm *tracingMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "remove_rule", trace.WithAttributes(
attribute.String("id", id),
))
defer span.End()
return tm.svc.RemoveRule(ctx, session, id)
}
func (tm *tracingMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "enable_rule", trace.WithAttributes(
attribute.String("id", id),
))
defer span.End()
return tm.svc.EnableRule(ctx, session, id)
}
func (tm *tracingMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "disable_rule", trace.WithAttributes(
attribute.String("id", id),
))
defer span.End()
return tm.svc.DisableRule(ctx, session, id)
}
func (tm *tracingMiddleware) Handle(msg *messaging.Message) error {
_, span := smqTracing.StartSpan(context.Background(), tm.tracer, "handle", trace.WithAttributes(
attribute.String("channel", msg.Channel),
attribute.String("subtopic", msg.Subtopic),
))
defer span.End()
return tm.svc.Handle(msg)
}
func (tm *tracingMiddleware) StartScheduler(ctx context.Context) error {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "start_scheduler")
defer span.End()
return tm.svc.StartScheduler(ctx)
}
func (tm *tracingMiddleware) Cancel() error {
return tm.svc.Cancel()
}
File diff suppressed because it is too large Load Diff
+1500 -12
View File
File diff suppressed because it is too large Load Diff
-68
View File
@@ -1,68 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package re
import (
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/permissions"
"github.com/absmach/supermq/pkg/policies"
)
const EntityType = "rules"
const (
OpAddRule permissions.Operation = iota
OpViewRule
OpUpdateRule
OpUpdateRuleTags
OpUpdateRuleSchedule
OpListRules
OpRemoveRule
OpEnableRule
OpDisableRule
)
const (
OpAddRuleStr = "OpAddRule"
OpViewRuleStr = "OpViewRule"
OpUpdateRuleStr = "OpUpdateRule"
OpUpdateRuleTagsStr = "OpUpdateRuleTags"
OpUpdateRuleScheduleStr = "OpUpdateRuleSchedule"
OpListRulesStr = "OpListRules"
OpRemoveRuleStr = "OpRemoveRule"
OpEnableRuleStr = "OpEnableRule"
OpDisableRuleStr = "OpDisableRule"
)
func GetPermission(op permissions.Operation) (string, error) {
if op < OpAddRule || op > OpDisableRule {
return "", errors.New("invalid operation")
}
return policies.MembershipPermission, nil
}
func OperationName(op permissions.Operation) string {
switch op {
case OpAddRule:
return OpAddRuleStr
case OpViewRule:
return OpViewRuleStr
case OpUpdateRule:
return OpUpdateRuleStr
case OpUpdateRuleTags:
return OpUpdateRuleTagsStr
case OpUpdateRuleSchedule:
return OpUpdateRuleScheduleStr
case OpListRules:
return OpListRulesStr
case OpRemoveRule:
return OpRemoveRuleStr
case OpEnableRule:
return OpEnableRuleStr
case OpDisableRule:
return OpDisableRuleStr
default:
return "unknown"
}
}
+64
View File
@@ -0,0 +1,64 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package operations
import (
"github.com/absmach/supermq/pkg/permissions"
)
const EntityType = "rule"
// Rule Operations.
const (
OpAddRule permissions.Operation = iota
OpViewRule
OpUpdateRule
OpUpdateRuleTags
OpUpdateRuleSchedule
OpRemoveRule
OpListRules
OpEnableRule
OpDisableRule
)
func OperationDetails() map[permissions.Operation]permissions.OperationDetails {
return map[permissions.Operation]permissions.OperationDetails{
OpAddRule: {
Name: "add",
PermissionRequired: true,
},
OpViewRule: {
Name: "view",
PermissionRequired: true,
},
OpUpdateRule: {
Name: "update",
PermissionRequired: true,
},
OpUpdateRuleTags: {
Name: "update_tags",
PermissionRequired: true,
},
OpUpdateRuleSchedule: {
Name: "update_schedule",
PermissionRequired: true,
},
OpRemoveRule: {
Name: "delete",
PermissionRequired: true,
},
OpListRules: {
Name: "list",
PermissionRequired: true,
},
OpEnableRule: {
Name: "enable",
PermissionRequired: true,
},
OpDisableRule: {
Name: "disable",
PermissionRequired: true,
},
}
}
+20 -2
View File
@@ -4,12 +4,20 @@
package postgres
import (
dpostgres "github.com/absmach/supermq/domains/postgres"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
_ "github.com/jackc/pgx/v5/stdlib" // required for SQL access
migrate "github.com/rubenv/sql-migrate"
)
func Migration() *migrate.MemoryMigrationSource {
return &migrate.MemoryMigrationSource{
func Migration() (*migrate.MemoryMigrationSource, error) {
rolesMigration, err := rolesPostgres.Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName)
if err != nil {
return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err)
}
rulesMigration := &migrate.MemoryMigrationSource{
Migrations: []*migrate.Migration{
{
Id: "rules_01",
@@ -65,4 +73,14 @@ func Migration() *migrate.MemoryMigrationSource {
},
},
}
rulesMigration.Migrations = append(rulesMigration.Migrations, rolesMigration.Migrations...)
domainsMigration, err := dpostgres.Migration()
if err != nil {
return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err)
}
rulesMigration.Migrations = append(rulesMigration.Migrations, domainsMigration.Migrations...)
return rulesMigration, nil
}
+165 -1
View File
@@ -10,19 +10,32 @@ import (
"strings"
"time"
mgPolicies "github.com/absmach/magistrala/pkg/policies"
"github.com/absmach/magistrala/re"
api "github.com/absmach/supermq/api/http"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/absmach/supermq/pkg/postgres"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
)
const (
rolesTableNamePrefix = "rules"
entityTableName = "rules"
entityIDColumnName = "id"
)
type PostgresRepository struct {
DB postgres.Database
rolesPostgres.Repository
}
func NewRepository(db postgres.Database) re.Repository {
return &PostgresRepository{DB: db}
rolesRepo := rolesPostgres.NewRepository(db, mgPolicies.RulesType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
return &PostgresRepository{
DB: db,
Repository: rolesRepo,
}
}
func (repo *PostgresRepository) AddRule(ctx context.Context, r re.Rule) (re.Rule, error) {
@@ -82,6 +95,157 @@ func (repo *PostgresRepository) ViewRule(ctx context.Context, id string) (re.Rul
return ret, nil
}
func (repo *PostgresRepository) RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (re.Rule, error) {
query := `
WITH selected_rule AS (
SELECT
r.id,
r.domain_id
FROM
rules r
WHERE
r.id = :id
LIMIT 1
),
selected_rule_roles AS (
SELECT
rr.entity_id AS rule_id,
rrm.member_id AS member_id,
rr.id AS role_id,
rr."name" AS role_name,
jsonb_agg(DISTINCT rra."action") AS actions,
'direct' AS access_type,
'' AS access_provider_id
FROM
rules_roles rr
JOIN
rules_role_members rrm ON rr.id = rrm.role_id
JOIN
rules_role_actions rra ON rr.id = rra.role_id
JOIN
selected_rule sr ON sr.id = rr.entity_id
AND rrm.member_id = :member_id
GROUP BY
rr.entity_id, rr.id, rr.name, rrm.member_id
),
selected_domain_roles AS (
SELECT
sr.id AS rule_id,
drm.member_id AS member_id,
dr.id AS role_id,
dr."name" AS role_name,
jsonb_agg(DISTINCT all_actions."action") AS actions,
'domain' AS access_type,
dr.entity_id AS access_provider_id
FROM
domains d
JOIN
selected_rule sr ON sr.domain_id = d.id
JOIN
domains_roles dr ON dr.entity_id = d.id
JOIN
domains_role_members drm ON dr.id = drm.role_id
JOIN
domains_role_actions dra ON dr.id = dra.role_id
JOIN
domains_role_actions all_actions ON dr.id = all_actions.role_id
WHERE
drm.member_id = :member_id
AND dra."action" LIKE 'rule%'
GROUP BY
sr.id, dr.entity_id, dr.id, dr."name", drm.member_id
),
all_roles AS (
SELECT
srr.rule_id,
srr.member_id,
srr.role_id,
srr.role_name,
srr.actions,
srr.access_type,
srr.access_provider_id
FROM
selected_rule_roles srr
UNION
SELECT
sdr.rule_id,
sdr.member_id,
sdr.role_id,
sdr.role_name,
sdr.actions,
sdr.access_type,
sdr.access_provider_id
FROM
selected_domain_roles sdr
),
final_roles AS (
SELECT
ar.rule_id,
ar.member_id,
jsonb_agg(
jsonb_build_object(
'role_id', ar.role_id,
'role_name', ar.role_name,
'actions', ar.actions,
'access_type', ar.access_type,
'access_provider_id', ar.access_provider_id
)
) AS roles
FROM all_roles ar
GROUP BY
ar.rule_id, ar.member_id
)
SELECT
r2.id,
r2."name",
r2.domain_id,
r2.tags,
r2.metadata,
r2.input_channel,
r2.input_topic,
r2.outputs,
r2.status,
r2.logic_type,
r2.logic_value,
r2.time,
r2.recurring,
r2.recurring_period,
r2.start_datetime,
r2.created_at,
r2.created_by,
r2.updated_at,
r2.updated_by,
fr.member_id,
fr.roles
FROM rules r2
JOIN final_roles fr ON fr.rule_id = r2.id
`
parameters := map[string]any{
"id": id,
"member_id": memberID,
}
row, err := repo.DB.NamedQueryContext(ctx, query, parameters)
if err != nil {
return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer row.Close()
dbrule := dbRule{}
if !row.Next() {
return re.Rule{}, repoerr.ErrNotFound
}
if err := row.StructScan(&dbrule); err != nil {
return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
r, err := dbToRule(dbrule)
if err != nil {
return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
return r, nil
}
func (repo *PostgresRepository) UpdateRuleStatus(ctx context.Context, r re.Rule) (re.Rule, error) {
q := `UPDATE rules
SET status = :status, updated_at = :updated_at, updated_by = :updated_by
+11
View File
@@ -11,6 +11,7 @@ import (
"github.com/absmach/magistrala/pkg/schedule"
"github.com/absmach/magistrala/re"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/roles"
"github.com/jackc/pgtype"
)
@@ -35,6 +36,8 @@ type dbRule struct {
CreatedBy string `db:"created_by"`
UpdatedAt time.Time `db:"updated_at"`
UpdatedBy string `db:"updated_by"`
MemberID string `db:"member_id,omitempty"`
Roles json.RawMessage `db:"roles,omitempty"`
}
func ruleToDb(r re.Rule) (dbRule, error) {
@@ -108,6 +111,13 @@ func dbToRule(dto dbRule) (re.Rule, error) {
}
}
var roles []roles.MemberRoleActions
if dto.Roles != nil {
if err := json.Unmarshal(dto.Roles, &roles); err != nil {
return re.Rule{}, errors.Wrap(errors.ErrMalformedEntity, err)
}
}
return re.Rule{
ID: dto.ID,
Name: dto.Name,
@@ -132,6 +142,7 @@ func dbToRule(dto dbRule) (re.Rule, error) {
CreatedBy: dto.CreatedBy,
UpdatedAt: dto.UpdatedAt,
UpdatedBy: dto.UpdatedBy,
Roles: roles,
}, nil
}
+5 -1
View File
@@ -75,7 +75,11 @@ func TestMain(m *testing.M) {
SSLRootCert: "",
}
if db, err = postgres.Setup(dbConfig, *repostgres.Migration()); err != nil {
migration, err := repostgres.Migration()
if err != nil {
log.Fatalf("Could not get migration: %s", err)
}
if db, err = postgres.Setup(dbConfig, *migration); err != nil {
log.Fatalf("Could not setup test DB connection: %s", err)
}
+21 -16
View File
@@ -13,6 +13,7 @@ import (
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/roles"
)
const (
@@ -42,21 +43,22 @@ var outputRegistry = map[outputs.OutputType]func() Runnable{
}
type Rule struct {
ID string `json:"id"`
Name string `json:"name"`
DomainID string `json:"domain"`
Metadata Metadata `json:"metadata,omitempty"`
Tags []string `json:"tags,omitempty"`
InputChannel string `json:"input_channel"`
InputTopic string `json:"input_topic"`
Logic Script `json:"logic"`
Outputs Outputs `json:"outputs,omitempty"`
Schedule schedule.Schedule `json:"schedule,omitempty"`
Status Status `json:"status"`
CreatedAt time.Time `json:"created_at"`
CreatedBy string `json:"created_by"`
UpdatedAt time.Time `json:"updated_at"`
UpdatedBy string `json:"updated_by"`
ID string `json:"id"`
Name string `json:"name"`
DomainID string `json:"domain"`
Metadata Metadata `json:"metadata,omitempty"`
Tags []string `json:"tags,omitempty"`
InputChannel string `json:"input_channel"`
InputTopic string `json:"input_topic"`
Logic Script `json:"logic"`
Outputs Outputs `json:"outputs,omitempty"`
Schedule schedule.Schedule `json:"schedule,omitempty"`
Status Status `json:"status"`
CreatedAt time.Time `json:"created_at"`
CreatedBy string `json:"created_by"`
UpdatedAt time.Time `json:"updated_at"`
UpdatedBy string `json:"updated_by"`
Roles []roles.MemberRoleActions `json:"roles,omitempty"`
}
// EventEncode converts a Rule struct to map[string]any at event producer.
@@ -224,7 +226,7 @@ type Page struct {
type Service interface {
messaging.MessageHandler
AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error)
ViewRule(ctx context.Context, session authn.Session, id string) (Rule, error)
ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (Rule, error)
UpdateRule(ctx context.Context, session authn.Session, r Rule) (Rule, error)
UpdateRuleTags(ctx context.Context, session authn.Session, r Rule) (Rule, error)
UpdateRuleSchedule(ctx context.Context, session authn.Session, r Rule) (Rule, error)
@@ -234,11 +236,13 @@ type Service interface {
DisableRule(ctx context.Context, session authn.Session, id string) (Rule, error)
StartScheduler(ctx context.Context) error
roles.RoleManager
}
type Repository interface {
AddRule(ctx context.Context, r Rule) (Rule, error)
ViewRule(ctx context.Context, id string) (Rule, error)
RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (Rule, error)
UpdateRule(ctx context.Context, r Rule) (Rule, error)
UpdateRuleTags(ctx context.Context, r Rule) (Rule, error)
UpdateRuleSchedule(ctx context.Context, r Rule) (Rule, error)
@@ -246,4 +250,5 @@ type Repository interface {
UpdateRuleStatus(ctx context.Context, r Rule) (Rule, error)
ListRules(ctx context.Context, pm PageMeta) (Page, error)
UpdateRuleDue(ctx context.Context, id string, due time.Time) (Rule, error)
roles.Repository
}
+57 -14
View File
@@ -11,11 +11,14 @@ import (
"github.com/absmach/magistrala/pkg/emailer"
pkglog "github.com/absmach/magistrala/pkg/logger"
"github.com/absmach/magistrala/pkg/ticker"
"github.com/absmach/magistrala/re/operations"
"github.com/absmach/supermq"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
)
var ErrGoroutinesNotAllowed = errors.New("goroutines are not allowed in Go scripts")
@@ -30,23 +33,29 @@ type re struct {
ticker ticker.Ticker
email emailer.Emailer
readers grpcReadersV1.ReadersServiceClient
roles.ProvisionManageService
}
func NewService(repo Repository, runInfo chan pkglog.RunInfo, idp supermq.IDProvider, rePubSub messaging.PubSub, writersPub, alarmsPub messaging.Publisher, tck ticker.Ticker, emailer emailer.Emailer, readers grpcReadersV1.ReadersServiceClient) Service {
return &re{
repo: repo,
idp: idp,
runInfo: runInfo,
rePubSub: rePubSub,
writersPub: writersPub,
alarmsPub: alarmsPub,
ticker: tck,
email: emailer,
readers: readers,
func NewService(repo Repository, runInfo chan pkglog.RunInfo, policy policies.Service, idp supermq.IDProvider, rePubSub messaging.PubSub, writersPub, alarmsPub messaging.Publisher, tck ticker.Ticker, emailer emailer.Emailer, readers grpcReadersV1.ReadersServiceClient, availableActions []roles.Action, builtInRoles map[roles.BuiltInRoleName][]roles.Action) (Service, error) {
rpms, err := roles.NewProvisionManageService(operations.EntityType, repo, policy, idp, availableActions, builtInRoles)
if err != nil {
return nil, err
}
return &re{
repo: repo,
idp: idp,
runInfo: runInfo,
rePubSub: rePubSub,
writersPub: writersPub,
alarmsPub: alarmsPub,
ticker: tck,
email: emailer,
readers: readers,
ProvisionManageService: rpms,
}, nil
}
func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) {
func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (retRule Rule, retErr error) {
if r.Logic.Type == GoType && goKeywordRegex.MatchString(r.Logic.Value) {
return Rule{}, errors.Wrap(svcerr.ErrMalformedEntity, ErrGoroutinesNotAllowed)
}
@@ -72,11 +81,45 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule,
return Rule{}, errors.Wrap(svcerr.ErrCreateEntity, err)
}
defer func() {
if retErr != nil {
if errRollBack := re.repo.RemoveRule(ctx, rule.ID); errRollBack != nil {
retErr = errors.Wrap(retErr, errors.Wrap(svcerr.ErrRollbackRepo, errRollBack))
}
}
}()
newBuiltInRoleMembers := map[roles.BuiltInRoleName][]roles.Member{
BuiltInRoleAdmin: {roles.Member(session.UserID)},
}
optionalPolicies := []policies.Policy{
{
SubjectType: policies.DomainType,
Subject: session.DomainID,
Relation: policies.DomainRelation,
ObjectType: operations.EntityType,
Object: rule.ID,
},
}
_, err = re.AddNewEntitiesRoles(ctx, session.DomainID, session.UserID, []string{rule.ID}, optionalPolicies, newBuiltInRoleMembers)
if err != nil {
return Rule{}, errors.Wrap(svcerr.ErrAddPolicies, err)
}
return rule, nil
}
func (re *re) ViewRule(ctx context.Context, session authn.Session, id string) (Rule, error) {
rule, err := re.repo.ViewRule(ctx, id)
func (re *re) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (Rule, error) {
var rule Rule
var err error
switch withRoles {
case true:
rule, err = re.repo.RetrieveByIDWithRoles(ctx, id, session.UserID)
default:
rule, err = re.repo.ViewRule(ctx, id)
}
if err != nil {
return Rule{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
+197 -32
View File
@@ -26,6 +26,8 @@ import (
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
pubsubmocks "github.com/absmach/supermq/pkg/messaging/mocks"
policymocks "github.com/absmach/supermq/pkg/policies/mocks"
"github.com/absmach/supermq/pkg/roles"
"github.com/absmach/supermq/pkg/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@@ -59,27 +61,40 @@ var (
}
)
func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *tmocks.Ticker, *emocks.Emailer) {
func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *tmocks.Ticker, *emocks.Emailer, *policymocks.Service) {
repo := new(mocks.Repository)
mockTicker := new(tmocks.Ticker)
idProvider := uuid.NewMock()
pubsub := pubsubmocks.NewPubSub(t)
readersSvc := new(readmocks.ReadersServiceClient)
e := new(emocks.Emailer)
return re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc), repo, pubsub, mockTicker, e
policy := new(policymocks.Service)
availableActions := []roles.Action{}
builtInRoles := map[roles.BuiltInRoleName][]roles.Action{
"admin": availableActions,
}
svc, err := re.NewService(repo, runInfo, policy, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc, availableActions, builtInRoles)
if err != nil {
t.Fatalf("Failed to create service: %v", err)
}
return svc, repo, pubsub, mockTicker, e, policy
}
func TestAddRule(t *testing.T) {
// nolint:dogsled
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
svc, repo, _, _, _, policies := newService(t, make(chan pkglog.RunInfo))
ruleName := namegen.Generate()
now := time.Now().Add(time.Hour)
cases := []struct {
desc string
session authn.Session
rule re.Rule
res re.Rule
err error
desc string
session authn.Session
rule re.Rule
res re.Rule
err error
addPoliciesErr error
deletePolicies error
addRoleErr error
deleteErr error
}{
{
desc: "Add rule successfully",
@@ -109,7 +124,10 @@ func TestAddRule(t *testing.T) {
CreatedBy: userID,
DomainID: domainID,
},
err: nil,
err: nil,
addPoliciesErr: nil,
addRoleErr: nil,
deleteErr: nil,
},
{
desc: "Add rule with failed repo",
@@ -126,7 +144,11 @@ func TestAddRule(t *testing.T) {
Time: now,
},
},
err: repoerr.ErrCreateEntity,
err: repoerr.ErrCreateEntity,
addPoliciesErr: nil,
deletePolicies: nil,
addRoleErr: nil,
deleteErr: nil,
},
{
desc: "Add rule with non-zero StartDateTime",
@@ -158,7 +180,136 @@ func TestAddRule(t *testing.T) {
CreatedBy: userID,
DomainID: domainID,
},
err: nil,
err: nil,
addPoliciesErr: nil,
addRoleErr: nil,
deleteErr: nil,
},
{
desc: "Add rule with failed to add roles and failed to delete policies",
session: authn.Session{
UserID: userID,
DomainID: domainID,
},
rule: re.Rule{
Name: ruleName,
InputChannel: inputChannel,
Schedule: pkgSch.Schedule{
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
},
},
res: re.Rule{
Name: ruleName,
ID: ruleID,
InputChannel: inputChannel,
Schedule: pkgSch.Schedule{
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
},
Status: re.EnabledStatus,
CreatedBy: userID,
DomainID: domainID,
},
addRoleErr: svcerr.ErrCreateEntity,
deletePolicies: svcerr.ErrRemoveEntity,
err: svcerr.ErrRemoveEntity,
},
{
desc: "Add rule with failed to add policies",
session: authn.Session{
UserID: userID,
DomainID: domainID,
},
rule: re.Rule{
Name: ruleName,
InputChannel: inputChannel,
Schedule: pkgSch.Schedule{
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
},
},
res: re.Rule{
Name: ruleName,
ID: ruleID,
InputChannel: inputChannel,
Schedule: pkgSch.Schedule{
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
},
Status: re.EnabledStatus,
CreatedBy: userID,
DomainID: domainID,
},
addPoliciesErr: svcerr.ErrAuthorization,
err: svcerr.ErrAddPolicies,
},
{
desc: "Add rule with failed to add policies and failed rollback",
session: authn.Session{
UserID: userID,
DomainID: domainID,
},
rule: re.Rule{
Name: ruleName,
InputChannel: inputChannel,
Schedule: pkgSch.Schedule{
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
},
},
res: re.Rule{
Name: ruleName,
ID: ruleID,
InputChannel: inputChannel,
Schedule: pkgSch.Schedule{
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
},
Status: re.EnabledStatus,
CreatedBy: userID,
DomainID: domainID,
},
addPoliciesErr: svcerr.ErrAuthorization,
deleteErr: svcerr.ErrRemoveEntity,
err: svcerr.ErrRollbackRepo,
},
{
desc: "Add rule with failed to add roles",
session: authn.Session{
UserID: userID,
DomainID: domainID,
},
rule: re.Rule{
Name: ruleName,
InputChannel: inputChannel,
Schedule: pkgSch.Schedule{
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
},
},
res: re.Rule{
Name: ruleName,
ID: ruleID,
InputChannel: inputChannel,
Schedule: pkgSch.Schedule{
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
},
Status: re.EnabledStatus,
CreatedBy: userID,
DomainID: domainID,
},
addRoleErr: svcerr.ErrCreateEntity,
err: svcerr.ErrAddPolicies,
},
{
desc: "Add rule with Go script containing goroutines",
@@ -186,6 +337,10 @@ func TestAddRule(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := repo.On("AddRule", mock.Anything, mock.Anything).Return(tc.res, tc.err)
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr)
policyCall2 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePolicies).Maybe()
repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr)
repoCall2 := repo.On("Remove", context.Background(), mock.Anything).Return(tc.deleteErr).Maybe()
res, err := svc.AddRule(context.Background(), tc.session, tc.rule)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
@@ -193,14 +348,18 @@ func TestAddRule(t *testing.T) {
assert.Equal(t, tc.rule.Name, res.Name)
assert.Equal(t, tc.rule.Schedule, res.Schedule)
}
defer repoCall.Unset()
policyCall.Unset()
policyCall2.Unset()
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
func TestViewRule(t *testing.T) {
// nolint:dogsled
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
now := time.Now().Add(time.Hour)
cases := []struct {
@@ -246,7 +405,7 @@ func TestViewRule(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := repo.On("ViewRule", mock.Anything, mock.Anything).Return(tc.res, tc.err)
res, err := svc.ViewRule(context.Background(), tc.session, tc.id)
res, err := svc.ViewRule(context.Background(), tc.session, tc.id, false)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
@@ -259,7 +418,7 @@ func TestViewRule(t *testing.T) {
func TestUpdateRule(t *testing.T) {
// nolint:dogsled
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
newName := namegen.Generate()
now := time.Now().Add(time.Hour)
@@ -370,7 +529,7 @@ func TestUpdateRule(t *testing.T) {
func TestUpdateRuleTags(t *testing.T) {
// nolint:dogsled
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
cases := []struct {
desc string
@@ -427,7 +586,7 @@ func TestUpdateRuleTags(t *testing.T) {
func TestUpdateRuleSchedule(t *testing.T) {
// nolint:dogsled
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
now := time.Now().UTC()
future := now.Add(2 * time.Hour)
@@ -495,7 +654,7 @@ func TestUpdateRuleSchedule(t *testing.T) {
func TestListRules(t *testing.T) {
// nolint:dogsled
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
numRules := 50
now := time.Now().Add(time.Hour)
var rules []re.Rule
@@ -629,13 +788,14 @@ func TestListRules(t *testing.T) {
func TestRemoveRule(t *testing.T) {
// nolint:dogsled
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
svc, repo, _, _, _, policies := newService(t, make(chan pkglog.RunInfo))
cases := []struct {
desc string
session authn.Session
id string
err error
desc string
session authn.Session
id string
err error
deletePoliciesErr error
}{
{
desc: "remove rule successfully",
@@ -643,8 +803,9 @@ func TestRemoveRule(t *testing.T) {
UserID: userID,
DomainID: domainID,
},
id: ruleID,
err: nil,
id: ruleID,
err: nil,
deletePoliciesErr: nil,
},
{
desc: "remove rule with failed repo",
@@ -652,25 +813,28 @@ func TestRemoveRule(t *testing.T) {
UserID: userID,
DomainID: domainID,
},
id: ruleID,
err: svcerr.ErrRemoveEntity,
id: ruleID,
err: svcerr.ErrRemoveEntity,
deletePoliciesErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := repo.On("RemoveRule", mock.Anything, mock.Anything).Return(tc.err)
policyCall := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr)
err := svc.RemoveRule(context.Background(), tc.session, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
defer repoCall.Unset()
policyCall.Unset()
repoCall.Unset()
})
}
}
func TestEnableRule(t *testing.T) {
// nolint:dogsled
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
now := time.Now()
@@ -730,7 +894,7 @@ func TestEnableRule(t *testing.T) {
func TestDisableRule(t *testing.T) {
// nolint:dogsled
svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo))
svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo))
now := time.Now()
@@ -789,7 +953,7 @@ func TestDisableRule(t *testing.T) {
}
func TestHandle(t *testing.T) {
svc, repo, pubmocks, _, emailer := newService(t, make(chan pkglog.RunInfo))
svc, repo, pubmocks, _, emailer, _ := newService(t, make(chan pkglog.RunInfo))
now := time.Now()
scheduled := false
@@ -1461,7 +1625,8 @@ func TestHandle(t *testing.T) {
func TestStartScheduler(t *testing.T) {
now := time.Now().Truncate(time.Minute)
ri := make(chan pkglog.RunInfo)
svc, repo, _, ticker, _ := newService(t, ri)
// nolint:dogsled
svc, repo, _, ticker, _, _ := newService(t, ri)
ctxCases := []struct {
desc string