NOISSUE - Refactor listing for rules and reports (#433)

* add access control to rules engine

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* update authorization method

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* revert code

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* initial implementation

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* remove domain from method

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix failing linter

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix userid parameter

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* update checksuperadmin method

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* revert changes

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* address comments

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

---------

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Steve Munene
2026-03-16 16:39:49 +03:00
committed by GitHub
parent 2ef8437d8b
commit 7fb5dd7b55
16 changed files with 826 additions and 120 deletions
+2 -2
View File
@@ -43,7 +43,7 @@ func (re *re) Handle(msg *messaging.Message) error {
Scheduled: &scheduledFalse,
}
ctx := context.Background()
page, err := re.repo.ListRules(ctx, pm)
page, err := re.repo.ListAllRules(ctx, pm)
if err != nil {
return err
}
@@ -130,7 +130,7 @@ func (re *re) StartScheduler(ctx context.Context) error {
ScheduledBefore: &due,
}
page, err := re.repo.ListRules(ctx, pm)
page, err := re.repo.ListAllRules(ctx, pm)
if err != nil {
re.runInfo <- pkglog.RunInfo{
Level: slog.LevelError,
+23 -2
View File
@@ -11,6 +11,7 @@ import (
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/messaging"
"github.com/absmach/supermq/pkg/permissions"
"github.com/absmach/supermq/pkg/policies"
@@ -90,8 +91,12 @@ func (am *authorizationMiddleware) UpdateRuleSchedule(ctx context.Context, sessi
}
func (am *authorizationMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) {
if err := am.authorize(ctx, operations.OpListRules, session, policies.DomainType, session.DomainID); err != nil {
return re.Page{}, errors.Wrap(errDomainViewRules, err)
switch err := am.checkSuperAdmin(ctx, session); {
case err == nil:
session.SuperAdmin = true
case errors.Contains(err, svcerr.ErrSuperAdminAction):
default:
return re.Page{}, err
}
return am.svc.ListRules(ctx, session, pm)
@@ -168,3 +173,19 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions
return nil
}
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error {
if session.Role != authn.SuperAdminRole {
return svcerr.ErrSuperAdminAction
}
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
SubjectType: policies.UserType,
Subject: session.UserID,
Permission: policies.AdminPermission,
ObjectType: policies.PlatformType,
Object: policies.SuperMQObject,
}, nil); err != nil {
return err
}
return nil
}
+92 -20
View File
@@ -178,6 +178,72 @@ func (_c *Repository_AddRule_Call) RunAndReturn(run func(ctx context.Context, r
return _c
}
// ListAllRules provides a mock function for the type Repository
func (_mock *Repository) ListAllRules(ctx context.Context, pm re.PageMeta) (re.Page, error) {
ret := _mock.Called(ctx, pm)
if len(ret) == 0 {
panic("no return value specified for ListAllRules")
}
var r0 re.Page
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) (re.Page, error)); ok {
return returnFunc(ctx, pm)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) re.Page); ok {
r0 = returnFunc(ctx, pm)
} else {
r0 = ret.Get(0).(re.Page)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, re.PageMeta) error); ok {
r1 = returnFunc(ctx, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_ListAllRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllRules'
type Repository_ListAllRules_Call struct {
*mock.Call
}
// ListAllRules is a helper method to define mock.On call
// - ctx context.Context
// - pm re.PageMeta
func (_e *Repository_Expecter) ListAllRules(ctx interface{}, pm interface{}) *Repository_ListAllRules_Call {
return &Repository_ListAllRules_Call{Call: _e.mock.On("ListAllRules", ctx, pm)}
}
func (_c *Repository_ListAllRules_Call) Run(run func(ctx context.Context, pm re.PageMeta)) *Repository_ListAllRules_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 re.PageMeta
if args[1] != nil {
arg1 = args[1].(re.PageMeta)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Repository_ListAllRules_Call) Return(page re.Page, err error) *Repository_ListAllRules_Call {
_c.Call.Return(page, err)
return _c
}
func (_c *Repository_ListAllRules_Call) RunAndReturn(run func(ctx context.Context, pm re.PageMeta) (re.Page, error)) *Repository_ListAllRules_Call {
_c.Call.Return(run)
return _c
}
// ListEntityMembers provides a mock function for the type Repository
func (_mock *Repository) ListEntityMembers(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) {
ret := _mock.Called(ctx, entityID, pageQuery)
@@ -250,68 +316,74 @@ func (_c *Repository_ListEntityMembers_Call) RunAndReturn(run func(ctx context.C
return _c
}
// ListRules provides a mock function for the type Repository
func (_mock *Repository) ListRules(ctx context.Context, pm re.PageMeta) (re.Page, error) {
ret := _mock.Called(ctx, pm)
// ListUserRules provides a mock function for the type Repository
func (_mock *Repository) ListUserRules(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error) {
ret := _mock.Called(ctx, userID, pm)
if len(ret) == 0 {
panic("no return value specified for ListRules")
panic("no return value specified for ListUserRules")
}
var r0 re.Page
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) (re.Page, error)); ok {
return returnFunc(ctx, pm)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, re.PageMeta) (re.Page, error)); ok {
return returnFunc(ctx, userID, pm)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) re.Page); ok {
r0 = returnFunc(ctx, pm)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, re.PageMeta) re.Page); ok {
r0 = returnFunc(ctx, userID, pm)
} else {
r0 = ret.Get(0).(re.Page)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, re.PageMeta) error); ok {
r1 = returnFunc(ctx, pm)
if returnFunc, ok := ret.Get(1).(func(context.Context, string, re.PageMeta) error); ok {
r1 = returnFunc(ctx, userID, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_ListRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListRules'
type Repository_ListRules_Call struct {
// Repository_ListUserRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRules'
type Repository_ListUserRules_Call struct {
*mock.Call
}
// ListRules is a helper method to define mock.On call
// ListUserRules is a helper method to define mock.On call
// - ctx context.Context
// - userID string
// - pm re.PageMeta
func (_e *Repository_Expecter) ListRules(ctx interface{}, pm interface{}) *Repository_ListRules_Call {
return &Repository_ListRules_Call{Call: _e.mock.On("ListRules", ctx, pm)}
func (_e *Repository_Expecter) ListUserRules(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserRules_Call {
return &Repository_ListUserRules_Call{Call: _e.mock.On("ListUserRules", ctx, userID, pm)}
}
func (_c *Repository_ListRules_Call) Run(run func(ctx context.Context, pm re.PageMeta)) *Repository_ListRules_Call {
func (_c *Repository_ListUserRules_Call) Run(run func(ctx context.Context, userID string, pm re.PageMeta)) *Repository_ListUserRules_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 re.PageMeta
var arg1 string
if args[1] != nil {
arg1 = args[1].(re.PageMeta)
arg1 = args[1].(string)
}
var arg2 re.PageMeta
if args[2] != nil {
arg2 = args[2].(re.PageMeta)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *Repository_ListRules_Call) Return(page re.Page, err error) *Repository_ListRules_Call {
func (_c *Repository_ListUserRules_Call) Return(page re.Page, err error) *Repository_ListUserRules_Call {
_c.Call.Return(page, err)
return _c
}
func (_c *Repository_ListRules_Call) RunAndReturn(run func(ctx context.Context, pm re.PageMeta) (re.Page, error)) *Repository_ListRules_Call {
func (_c *Repository_ListUserRules_Call) RunAndReturn(run func(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error)) *Repository_ListUserRules_Call {
_c.Call.Return(run)
return _c
}
+86 -26
View File
@@ -357,33 +357,10 @@ func (repo *PostgresRepository) RemoveRule(ctx context.Context, id string) error
return nil
}
func (repo *PostgresRepository) ListRules(ctx context.Context, pm re.PageMeta) (re.Page, error) {
pgData := ""
if pm.Limit != 0 {
pgData = "LIMIT :limit"
}
if pm.Offset != 0 {
pgData += " OFFSET :offset"
}
func (repo *PostgresRepository) ListAllRules(ctx context.Context, pm re.PageMeta) (re.Page, error) {
pq := pageRulesQuery(pm)
dir := api.DescDir
if pm.Dir == api.AscDir {
dir = api.AscDir
}
orderClause := ""
switch pm.Order {
case api.NameKey:
orderClause = fmt.Sprintf("ORDER BY name %s, id %s", dir, dir)
case api.CreatedAtOrder:
orderClause = fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir)
case api.UpdatedAtOrder:
orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
default:
orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
}
orderClause := rulesOrderClause(pm)
pgData := rulesPageData(pm)
q := fmt.Sprintf(`
SELECT id, name, domain_id, tags, input_channel, input_topic, logic_type, logic_value, outputs,
@@ -425,6 +402,62 @@ func (repo *PostgresRepository) ListRules(ctx context.Context, pm re.PageMeta) (
return ret, nil
}
func (repo *PostgresRepository) ListUserRules(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error) {
pm.UserID = userID
pq := pageRulesQuery(pm)
orderClause := rulesOrderClause(pm)
pgData := rulesPageData(pm)
userJoin := `
INNER JOIN rules_roles rr ON rr.entity_id = r.id
INNER JOIN rules_role_members rrm ON rrm.role_id = rr.id AND rrm.member_id = :user_id
`
innerQ := fmt.Sprintf(`
SELECT DISTINCT r.id, r.name, r.domain_id, r.tags, r.input_channel, r.input_topic, r.logic_type, r.logic_value, r.outputs,
r.start_datetime, r.time, r.recurring, r.recurring_period, r.created_at, r.created_by, r.updated_at, r.updated_by, r.status
FROM rules r
%s
%s
`, userJoin, pq)
q := fmt.Sprintf(`
SELECT * FROM (%s) AS sub %s %s;
`, innerQ, orderClause, pgData)
rows, err := repo.DB.NamedQueryContext(ctx, q, pm)
if err != nil {
return re.Page{}, err
}
defer rows.Close()
var rules []re.Rule
for rows.Next() {
var r dbRule
if err := rows.StructScan(&r); err != nil {
return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
ret, err := dbToRule(r)
if err != nil {
return re.Page{}, err
}
rules = append(rules, ret)
}
cq := fmt.Sprintf(`SELECT COUNT(*) FROM (%s) AS count_sub;`, innerQ)
total, err := postgres.Total(ctx, repo.DB, cq, pm)
if err != nil {
return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
return re.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
Rules: rules,
}, nil
}
func (repo *PostgresRepository) UpdateRuleDue(ctx context.Context, id string, due time.Time) (re.Rule, error) {
q := `
UPDATE rules
@@ -460,6 +493,33 @@ func (repo *PostgresRepository) UpdateRuleDue(ctx context.Context, id string, du
return rule, nil
}
func rulesOrderClause(pm re.PageMeta) string {
dir := api.DescDir
if pm.Dir == api.AscDir {
dir = api.AscDir
}
switch pm.Order {
case api.NameKey:
return fmt.Sprintf("ORDER BY name %s, id %s", dir, dir)
case api.CreatedAtOrder:
return fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir)
default:
return fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
}
}
func rulesPageData(pm re.PageMeta) string {
pgData := ""
if pm.Limit != 0 {
pgData = "LIMIT :limit"
}
if pm.Offset != 0 {
pgData += " OFFSET :offset"
}
return pgData
}
func pageRulesQuery(pm re.PageMeta) string {
var query []string
if pm.InputChannel != "" {
+187 -1
View File
@@ -890,7 +890,7 @@ func TestListRules(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
page, err := repo.ListRules(context.Background(), tc.pm)
page, err := repo.ListAllRules(context.Background(), tc.pm)
if tc.err != nil {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
return
@@ -935,6 +935,192 @@ func TestListRules(t *testing.T) {
}
}
func TestListUserRules(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := generateUUID(t)
userID := generateUUID(t)
otherUserID := generateUUID(t)
channelID := generateUUID(t)
// Create 10 rules; assign the first 4 to userID via a role.
var allRules []re.Rule
for i := range 10 {
r := re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: domainID,
InputChannel: channelID,
Logic: re.Script{Type: re.LuaType, Value: "return true"},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
}
rule, err := repo.AddRule(context.Background(), r)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
allRules = append(allRules, rule)
}
// Assign userID to the first 4 rules via direct role INSERT.
for i := range 4 {
roleID := generateUUID(t)
_, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", allRules[i].ID)
assert.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err))
_, err = db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, allRules[i].ID)
assert.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err))
}
cases := []struct {
desc string
userID string
pm re.PageMeta
count int
err error
}{
{
desc: "list user rules returns only accessible rules",
userID: userID,
pm: re.PageMeta{
Offset: 0,
Limit: 100,
Status: re.AllStatus,
},
count: 4,
err: nil,
},
{
desc: "list user rules with offset",
userID: userID,
pm: re.PageMeta{
Offset: 2,
Limit: 100,
Status: re.AllStatus,
},
count: 2,
err: nil,
},
{
desc: "list user rules with limit",
userID: userID,
pm: re.PageMeta{
Offset: 0,
Limit: 2,
Status: re.AllStatus,
},
count: 2,
err: nil,
},
{
desc: "list user rules with domain filter",
userID: userID,
pm: re.PageMeta{
Domain: domainID,
Offset: 0,
Limit: 100,
Status: re.AllStatus,
},
count: 4,
err: nil,
},
{
desc: "list user rules with channel filter",
userID: userID,
pm: re.PageMeta{
InputChannel: channelID,
Offset: 0,
Limit: 100,
Status: re.AllStatus,
},
count: 4,
err: nil,
},
{
desc: "list user rules with non-existing domain returns 0",
userID: userID,
pm: re.PageMeta{
Domain: generateUUID(t),
Offset: 0,
Limit: 100,
Status: re.AllStatus,
},
count: 0,
err: nil,
},
{
desc: "list rules for user with no role assignments returns 0",
userID: otherUserID,
pm: re.PageMeta{
Offset: 0,
Limit: 100,
Status: re.AllStatus,
},
count: 0,
err: nil,
},
{
desc: "list user rules ordered by name ascending",
userID: userID,
pm: re.PageMeta{
Offset: 0,
Limit: 100,
Status: re.AllStatus,
Order: nameOrder,
Dir: ascDir,
},
count: 4,
err: nil,
},
{
desc: "list user rules ordered by created_at descending",
userID: userID,
pm: re.PageMeta{
Offset: 0,
Limit: 100,
Status: re.AllStatus,
Order: createdAtOrder,
Dir: descDir,
},
count: 4,
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
page, err := repo.ListUserRules(context.Background(), tc.userID, tc.pm)
if tc.err != nil {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
return
}
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
assert.Equal(t, tc.count, len(page.Rules), fmt.Sprintf("%s: expected %d rules, got %d", tc.desc, tc.count, len(page.Rules)))
if len(page.Rules) > 1 {
switch tc.pm.Order {
case nameOrder:
if tc.pm.Dir == ascDir {
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
return page.Rules[i].Name <= page.Rules[j].Name
}), "Expected names to be sorted ascending")
}
case createdAtOrder:
if tc.pm.Dir == descDir {
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
return page.Rules[i].CreatedAt.After(page.Rules[j].CreatedAt)
}), "Expected created_at to be sorted descending")
}
}
}
})
}
}
func TestRemoveRule(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
+3 -1
View File
@@ -172,6 +172,7 @@ type PageMeta struct {
ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time
ScheduledAfter *time.Time `json:"scheduled_after,omitempty" db:"scheduled_after"` // Filter rules scheduled after this time
Recurring *schedule.Recurring `json:"recurring,omitempty" db:"recurring"` // Filter by recurring type
UserID string `json:"user_id,omitempty" db:"user_id"`
}
// EventEncode converts a PageMeta struct to map[string]any.
@@ -250,7 +251,8 @@ type Repository interface {
UpdateRuleSchedule(ctx context.Context, r Rule) (Rule, error)
RemoveRule(ctx context.Context, id string) error
UpdateRuleStatus(ctx context.Context, r Rule) (Rule, error)
ListRules(ctx context.Context, pm PageMeta) (Page, error)
ListAllRules(ctx context.Context, pm PageMeta) (Page, error)
ListUserRules(ctx context.Context, userID string, pm PageMeta) (Page, error)
UpdateRuleDue(ctx context.Context, id string, due time.Time) (Rule, error)
roles.Repository
}
+8 -1
View File
@@ -175,7 +175,14 @@ func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r R
func (re *re) ListRules(ctx context.Context, session authn.Session, pm PageMeta) (Page, error) {
pm.Domain = session.DomainID
page, err := re.repo.ListRules(ctx, pm)
if session.SuperAdmin {
page, err := re.repo.ListAllRules(ctx, pm)
if err != nil {
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return page, nil
}
page, err := re.repo.ListUserRules(ctx, session.UserID, pm)
if err != nil {
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
+43 -9
View File
@@ -868,11 +868,12 @@ func TestListRules(t *testing.T) {
}
cases := []struct {
desc string
session authn.Session
pageMeta re.PageMeta
res re.Page
err error
desc string
session authn.Session
pageMeta re.PageMeta
res re.Page
err error
superAdmin bool
}{
{
desc: "list rules successfully",
@@ -948,11 +949,44 @@ func TestListRules(t *testing.T) {
pageMeta: re.PageMeta{},
err: svcerr.ErrViewEntity,
},
{
desc: "list rules as super admin successfully",
session: authn.Session{
UserID: userID,
DomainID: domainID,
SuperAdmin: true,
},
pageMeta: re.PageMeta{},
res: re.Page{
Total: uint64(numRules),
Offset: 0,
Limit: 10,
Rules: rules[0:10],
},
superAdmin: true,
err: nil,
},
{
desc: "list rules as super admin with failed repo",
session: authn.Session{
UserID: userID,
DomainID: domainID,
SuperAdmin: true,
},
pageMeta: re.PageMeta{},
superAdmin: true,
err: svcerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(tc.res, tc.err)
var repoCall *mock.Call
if tc.superAdmin {
repoCall = repo.On("ListAllRules", mock.Anything, mock.Anything).Return(tc.res, tc.err)
} else {
repoCall = repo.On("ListUserRules", mock.Anything, mock.Anything, mock.Anything).Return(tc.res, tc.err)
}
res, err := svc.ListRules(context.Background(), tc.session, tc.pageMeta)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
@@ -1778,7 +1812,7 @@ func TestHandle(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
var err error
repoCall := repo.On("ListRules", mock.Anything, re.PageMeta{Domain: tc.message.Domain, InputChannel: tc.message.Channel, Scheduled: &scheduled}).Return(tc.page, tc.listErr).Run(func(args mock.Arguments) {
repoCall := repo.On("ListAllRules", mock.Anything, re.PageMeta{Domain: tc.message.Domain, InputChannel: tc.message.Channel, Scheduled: &scheduled}).Return(tc.page, tc.listErr).Run(func(args mock.Arguments) {
if tc.listErr != nil {
err = tc.listErr
}
@@ -1854,7 +1888,7 @@ func TestStartScheduler(t *testing.T) {
for _, tc := range ctxCases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(tc.page, tc.listErr)
repoCall := repo.On("ListAllRules", mock.Anything, mock.Anything).Return(tc.page, tc.listErr)
tickChan := make(chan time.Time)
tickCall := ticker.On("Tick").Return((<-chan time.Time)(tickChan))
tickCall1 := ticker.On("Stop").Return()
@@ -1989,7 +2023,7 @@ func TestStartScheduler(t *testing.T) {
Total: uint64(len(tc.rules)),
}
repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(page, tc.listErr)
repoCall := repo.On("ListAllRules", mock.Anything, mock.Anything).Return(page, tc.listErr)
repoCall2 := repo.On("UpdateRuleDue", mock.Anything, mock.Anything, mock.Anything).Return(re.Rule{}, tc.updateDueErr)
tickChan := make(chan time.Time, 1)
tickCall := ticker.On("Tick").Return((<-chan time.Time)(tickChan))
+1 -1
View File
@@ -27,7 +27,7 @@ func (r *report) StartScheduler(ctx context.Context) error {
ScheduledBefore: &due,
}
reportConfigs, err := r.repo.ListReportsConfig(ctx, pm)
reportConfigs, err := r.repo.ListAllReportsConfig(ctx, pm)
if err != nil {
r.runInfo <- pkglog.RunInfo{
Level: slog.LevelError,
+23 -2
View File
@@ -11,6 +11,7 @@ import (
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/permissions"
"github.com/absmach/supermq/pkg/policies"
rolemgr "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
@@ -93,8 +94,12 @@ func (am *authorizationMiddleware) RemoveReportConfig(ctx context.Context, sessi
}
func (am *authorizationMiddleware) ListReportsConfig(ctx context.Context, session authn.Session, pm reports.PageMeta) (reports.ReportConfigPage, error) {
if err := am.authorize(ctx, operations.OpListReportsConfig, session, policies.DomainType, session.DomainID); err != nil {
return reports.ReportConfigPage{}, errors.Wrap(errDomainViewConfigs, err)
switch err := am.checkSuperAdmin(ctx, session); {
case err == nil:
session.SuperAdmin = true
case errors.Contains(err, svcerr.ErrSuperAdminAction):
default:
return reports.ReportConfigPage{}, err
}
return am.svc.ListReportsConfig(ctx, session, pm)
@@ -187,3 +192,19 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions
return nil
}
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error {
if session.Role != authn.SuperAdminRole {
return svcerr.ErrSuperAdminAction
}
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
SubjectType: policies.UserType,
Subject: session.UserID,
Permission: policies.AdminPermission,
ObjectType: policies.PlatformType,
Object: policies.SuperMQObject,
}, nil); err != nil {
return err
}
return nil
}
+92 -20
View File
@@ -241,6 +241,72 @@ func (_c *Repository_DeleteReportTemplate_Call) RunAndReturn(run func(ctx contex
return _c
}
// ListAllReportsConfig provides a mock function for the type Repository
func (_mock *Repository) ListAllReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) {
ret := _mock.Called(ctx, pm)
if len(ret) == 0 {
panic("no return value specified for ListAllReportsConfig")
}
var r0 reports.ReportConfigPage
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) (reports.ReportConfigPage, error)); ok {
return returnFunc(ctx, pm)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) reports.ReportConfigPage); ok {
r0 = returnFunc(ctx, pm)
} else {
r0 = ret.Get(0).(reports.ReportConfigPage)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, reports.PageMeta) error); ok {
r1 = returnFunc(ctx, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_ListAllReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllReportsConfig'
type Repository_ListAllReportsConfig_Call struct {
*mock.Call
}
// ListAllReportsConfig is a helper method to define mock.On call
// - ctx context.Context
// - pm reports.PageMeta
func (_e *Repository_Expecter) ListAllReportsConfig(ctx interface{}, pm interface{}) *Repository_ListAllReportsConfig_Call {
return &Repository_ListAllReportsConfig_Call{Call: _e.mock.On("ListAllReportsConfig", ctx, pm)}
}
func (_c *Repository_ListAllReportsConfig_Call) Run(run func(ctx context.Context, pm reports.PageMeta)) *Repository_ListAllReportsConfig_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 reports.PageMeta
if args[1] != nil {
arg1 = args[1].(reports.PageMeta)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Repository_ListAllReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListAllReportsConfig_Call {
_c.Call.Return(reportConfigPage, err)
return _c
}
func (_c *Repository_ListAllReportsConfig_Call) RunAndReturn(run func(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListAllReportsConfig_Call {
_c.Call.Return(run)
return _c
}
// ListEntityMembers provides a mock function for the type Repository
func (_mock *Repository) ListEntityMembers(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) {
ret := _mock.Called(ctx, entityID, pageQuery)
@@ -313,68 +379,74 @@ func (_c *Repository_ListEntityMembers_Call) RunAndReturn(run func(ctx context.C
return _c
}
// ListReportsConfig provides a mock function for the type Repository
func (_mock *Repository) ListReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) {
ret := _mock.Called(ctx, pm)
// ListUserReportsConfig provides a mock function for the type Repository
func (_mock *Repository) ListUserReportsConfig(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error) {
ret := _mock.Called(ctx, userID, pm)
if len(ret) == 0 {
panic("no return value specified for ListReportsConfig")
panic("no return value specified for ListUserReportsConfig")
}
var r0 reports.ReportConfigPage
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) (reports.ReportConfigPage, error)); ok {
return returnFunc(ctx, pm)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, reports.PageMeta) (reports.ReportConfigPage, error)); ok {
return returnFunc(ctx, userID, pm)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) reports.ReportConfigPage); ok {
r0 = returnFunc(ctx, pm)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, reports.PageMeta) reports.ReportConfigPage); ok {
r0 = returnFunc(ctx, userID, pm)
} else {
r0 = ret.Get(0).(reports.ReportConfigPage)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, reports.PageMeta) error); ok {
r1 = returnFunc(ctx, pm)
if returnFunc, ok := ret.Get(1).(func(context.Context, string, reports.PageMeta) error); ok {
r1 = returnFunc(ctx, userID, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_ListReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListReportsConfig'
type Repository_ListReportsConfig_Call struct {
// Repository_ListUserReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserReportsConfig'
type Repository_ListUserReportsConfig_Call struct {
*mock.Call
}
// ListReportsConfig is a helper method to define mock.On call
// ListUserReportsConfig is a helper method to define mock.On call
// - ctx context.Context
// - userID string
// - pm reports.PageMeta
func (_e *Repository_Expecter) ListReportsConfig(ctx interface{}, pm interface{}) *Repository_ListReportsConfig_Call {
return &Repository_ListReportsConfig_Call{Call: _e.mock.On("ListReportsConfig", ctx, pm)}
func (_e *Repository_Expecter) ListUserReportsConfig(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserReportsConfig_Call {
return &Repository_ListUserReportsConfig_Call{Call: _e.mock.On("ListUserReportsConfig", ctx, userID, pm)}
}
func (_c *Repository_ListReportsConfig_Call) Run(run func(ctx context.Context, pm reports.PageMeta)) *Repository_ListReportsConfig_Call {
func (_c *Repository_ListUserReportsConfig_Call) Run(run func(ctx context.Context, userID string, pm reports.PageMeta)) *Repository_ListUserReportsConfig_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 reports.PageMeta
var arg1 string
if args[1] != nil {
arg1 = args[1].(reports.PageMeta)
arg1 = args[1].(string)
}
var arg2 reports.PageMeta
if args[2] != nil {
arg2 = args[2].(reports.PageMeta)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *Repository_ListReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListReportsConfig_Call {
func (_c *Repository_ListUserReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListUserReportsConfig_Call {
_c.Call.Return(reportConfigPage, err)
return _c
}
func (_c *Repository_ListReportsConfig_Call) RunAndReturn(run func(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListReportsConfig_Call {
func (_c *Repository_ListUserReportsConfig_Call) RunAndReturn(run func(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListUserReportsConfig_Call {
_c.Call.Return(run)
return _c
}
+86 -26
View File
@@ -405,39 +405,16 @@ func (repo *PostgresRepository) RemoveReportConfig(ctx context.Context, id strin
return nil
}
func (repo *PostgresRepository) ListReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) {
func (repo *PostgresRepository) ListAllReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) {
listReportsQuery := `
SELECT id, name, description, domain_id, metrics, email, config,
start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status
FROM report_config rc %s %s %s;
`
pgData := ""
if pm.Limit != 0 {
pgData = "LIMIT :limit"
}
if pm.Offset != 0 {
pgData += " OFFSET :offset"
}
pq := pageReportQuery(pm)
dir := api.DescDir
if pm.Dir == api.AscDir {
dir = api.AscDir
}
orderClause := ""
switch pm.Order {
case api.NameKey:
orderClause = fmt.Sprintf("ORDER BY name %s, id %s", dir, dir)
case api.CreatedAtOrder:
orderClause = fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir)
case api.UpdatedAtOrder:
orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
default:
orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
}
orderClause := reportsOrderClause(pm)
pgData := reportsPageData(pm)
q := fmt.Sprintf(listReportsQuery, pq, orderClause, pgData)
rows, err := repo.DB.NamedQueryContext(ctx, q, pm)
@@ -474,6 +451,63 @@ func (repo *PostgresRepository) ListReportsConfig(ctx context.Context, pm report
return ret, nil
}
func (repo *PostgresRepository) ListUserReportsConfig(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error) {
pq := pageReportQuery(pm)
orderClause := reportsOrderClause(pm)
pgData := reportsPageData(pm)
pm.UserID = userID
userJoin := `
INNER JOIN reports_roles rr ON rr.entity_id = rc.id
INNER JOIN reports_role_members rrm ON rrm.role_id = rr.id AND rrm.member_id = :user_id
`
whereClause := pq
innerQ := fmt.Sprintf(`
SELECT DISTINCT rc.id, rc.name, rc.description, rc.domain_id, rc.metrics, rc.email, rc.config,
rc.start_datetime, rc.due, rc.recurring, rc.recurring_period, rc.created_at, rc.created_by, rc.updated_at, rc.updated_by, rc.status
FROM report_config rc
%s
%s
`, userJoin, whereClause)
q := fmt.Sprintf(`
SELECT * FROM (%s) AS sub %s %s;
`, innerQ, orderClause, pgData)
rows, err := repo.DB.NamedQueryContext(ctx, q, pm)
if err != nil {
return reports.ReportConfigPage{}, err
}
defer rows.Close()
cfgs := []reports.ReportConfig{}
for rows.Next() {
var r dbReport
if err := rows.StructScan(&r); err != nil {
return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
rpt, err := dbToReport(r)
if err != nil {
return reports.ReportConfigPage{}, err
}
cfgs = append(cfgs, rpt)
}
cq := fmt.Sprintf(`SELECT COUNT(*) FROM (%s) AS count_sub;`, innerQ)
total, err := postgres.Total(ctx, repo.DB, cq, pm)
if err != nil {
return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
pm.Total = total
return reports.ReportConfigPage{
PageMeta: pm,
ReportConfigs: cfgs,
}, nil
}
func (repo *PostgresRepository) UpdateReportDue(ctx context.Context, id string, due time.Time) (reports.ReportConfig, error) {
q := `
UPDATE report_config
@@ -576,6 +610,32 @@ func (repo *PostgresRepository) DeleteReportTemplate(ctx context.Context, domain
return nil
}
func reportsOrderClause(pm reports.PageMeta) string {
dir := api.DescDir
if pm.Dir == api.AscDir {
dir = api.AscDir
}
switch pm.Order {
case api.NameKey:
return fmt.Sprintf("ORDER BY name %s, id %s", dir, dir)
case api.CreatedAtOrder:
return fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir)
default:
return fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
}
}
func reportsPageData(pm reports.PageMeta) string {
pgData := ""
if pm.Limit != 0 {
pgData = "LIMIT :limit"
}
if pm.Offset != 0 {
pgData += " OFFSET :offset"
}
return pgData
}
func pageReportQuery(pm reports.PageMeta) string {
var query []string
if pm.Status != reports.AllStatus {
+127 -1
View File
@@ -472,7 +472,7 @@ func TestListReportsConfig(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
page, err := repo.ListReportsConfig(context.Background(), tc.pageMeta)
page, err := repo.ListAllReportsConfig(context.Background(), tc.pageMeta)
if tc.err != nil {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
return
@@ -483,6 +483,132 @@ func TestListReportsConfig(t *testing.T) {
}
}
func TestListUserReportsConfig(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM report_config")
require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := generateUUID(t)
userID := generateUUID(t)
otherUserID := generateUUID(t)
num := 10
var allCfgs []reports.ReportConfig
for i := range num {
cfg := reports.ReportConfig{
ID: generateUUID(t),
Name: fmt.Sprintf("Report-%d", i),
DomainID: domainID,
Status: reports.EnabledStatus,
CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute),
UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute),
Metrics: []reports.ReqMetric{},
}
cfg, err := repo.AddReportConfig(context.Background(), cfg)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
allCfgs = append(allCfgs, cfg)
}
// Assign userID to the first 5 report configs via direct role INSERT.
for i := range 5 {
roleID := generateUUID(t)
_, err := db.Exec(`INSERT INTO reports_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", allCfgs[i].ID)
require.Nil(t, err, fmt.Sprintf("insert reports_roles unexpected error: %s", err))
_, err = db.Exec(`INSERT INTO reports_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, allCfgs[i].ID)
require.Nil(t, err, fmt.Sprintf("insert reports_role_members unexpected error: %s", err))
}
cases := []struct {
desc string
userID string
pageMeta reports.PageMeta
size int
err error
}{
{
desc: "list user reports returns only accessible reports",
userID: userID,
pageMeta: reports.PageMeta{
Domain: domainID,
Limit: 100,
Offset: 0,
},
size: 5,
err: nil,
},
{
desc: "list user reports with limit",
userID: userID,
pageMeta: reports.PageMeta{
Domain: domainID,
Limit: 3,
Offset: 0,
},
size: 3,
err: nil,
},
{
desc: "list user reports with offset",
userID: userID,
pageMeta: reports.PageMeta{
Domain: domainID,
Limit: 100,
Offset: 3,
},
size: 2,
err: nil,
},
{
desc: "list user reports with enabled status filter",
userID: userID,
pageMeta: reports.PageMeta{
Domain: domainID,
Limit: 100,
Status: reports.EnabledStatus,
},
size: 5,
err: nil,
},
{
desc: "list reports for user with no role assignments returns 0",
userID: otherUserID,
pageMeta: reports.PageMeta{
Domain: domainID,
Limit: 100,
Offset: 0,
},
size: 0,
err: nil,
},
{
desc: "list user reports with non-existing domain returns 0",
userID: userID,
pageMeta: reports.PageMeta{
Domain: generateUUID(t),
Limit: 100,
Offset: 0,
},
size: 0,
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
page, err := repo.ListUserReportsConfig(context.Background(), tc.userID, tc.pageMeta)
if tc.err != nil {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
return
}
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
require.Equal(t, tc.size, len(page.ReportConfigs), fmt.Sprintf("%s: expected %d reports, got %d", tc.desc, tc.size, len(page.ReportConfigs)))
})
}
}
func TestUpdateReportSchedule(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM report_config")
+3 -1
View File
@@ -395,6 +395,7 @@ type PageMeta struct {
Domain string `json:"domain_id,omitempty" db:"domain_id"`
ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time
ScheduledAfter *time.Time `json:"scheduled_after,omitempty" db:"scheduled_after"` // Filter rules scheduled after this time
UserID string `json:"user_id,omitempty" db:"user_id"`
}
type Repository interface {
@@ -405,7 +406,8 @@ type Repository interface {
UpdateReportSchedule(ctx context.Context, cfg ReportConfig) (ReportConfig, error)
RemoveReportConfig(ctx context.Context, id string) error
UpdateReportConfigStatus(ctx context.Context, cfg ReportConfig) (ReportConfig, error)
ListReportsConfig(ctx context.Context, pm PageMeta) (ReportConfigPage, error)
ListAllReportsConfig(ctx context.Context, pm PageMeta) (ReportConfigPage, error)
ListUserReportsConfig(ctx context.Context, userID string, pm PageMeta) (ReportConfigPage, error)
UpdateReportDue(ctx context.Context, id string, due time.Time) (ReportConfig, error)
UpdateReportTemplate(ctx context.Context, domainID, reportID string, template ReportTemplate) error
+8 -1
View File
@@ -159,7 +159,14 @@ func (r *report) RemoveReportConfig(ctx context.Context, session authn.Session,
func (r *report) ListReportsConfig(ctx context.Context, session authn.Session, pm PageMeta) (ReportConfigPage, error) {
pm.Domain = session.DomainID
page, err := r.repo.ListReportsConfig(ctx, pm)
if session.SuperAdmin {
page, err := r.repo.ListAllReportsConfig(ctx, pm)
if err != nil {
return ReportConfigPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return page, nil
}
page, err := r.repo.ListUserReportsConfig(ctx, session.UserID, pm)
if err != nil {
return ReportConfigPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
+42 -6
View File
@@ -328,11 +328,12 @@ func TestListReportsConfig(t *testing.T) {
}
cases := []struct {
desc string
session authn.Session
pageMeta reports.PageMeta
res reports.ReportConfigPage
err error
desc string
session authn.Session
pageMeta reports.PageMeta
res reports.ReportConfigPage
err error
superAdmin bool
}{
{
desc: "list report configs successfully",
@@ -399,11 +400,46 @@ func TestListReportsConfig(t *testing.T) {
pageMeta: reports.PageMeta{},
err: svcerr.ErrViewEntity,
},
{
desc: "list report configs as super admin successfully",
session: authn.Session{
UserID: userID,
DomainID: domainID,
SuperAdmin: true,
},
pageMeta: reports.PageMeta{},
res: reports.ReportConfigPage{
PageMeta: reports.PageMeta{
Total: uint64(numConfigs),
Offset: 0,
Limit: 10,
},
ReportConfigs: configs[0:10],
},
superAdmin: true,
err: nil,
},
{
desc: "list report configs as super admin with failed repo",
session: authn.Session{
UserID: userID,
DomainID: domainID,
SuperAdmin: true,
},
pageMeta: reports.PageMeta{},
superAdmin: true,
err: svcerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := repo.On("ListReportsConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err)
var repoCall *mock.Call
if tc.superAdmin {
repoCall = repo.On("ListAllReportsConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err)
} else {
repoCall = repo.On("ListUserReportsConfig", mock.Anything, mock.Anything, mock.Anything).Return(tc.res, tc.err)
}
res, err := svc.ListReportsConfig(context.Background(), tc.session, tc.pageMeta)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {