From 7fb5dd7b55b88748b2b3a326f14b01ed78c91ce9 Mon Sep 17 00:00:00 2001 From: Steve Munene Date: Mon, 16 Mar 2026 16:39:49 +0300 Subject: [PATCH] NOISSUE - Refactor listing for rules and reports (#433) * add access control to rules engine Signed-off-by: nyagamunene * update authorization method Signed-off-by: nyagamunene * revert code Signed-off-by: nyagamunene * initial implementation Signed-off-by: nyagamunene * remove domain from method Signed-off-by: nyagamunene * fix failing linter Signed-off-by: nyagamunene * fix userid parameter Signed-off-by: nyagamunene * update checksuperadmin method Signed-off-by: nyagamunene * revert changes Signed-off-by: nyagamunene * address comments Signed-off-by: nyagamunene --------- Signed-off-by: nyagamunene --- re/handlers.go | 4 +- re/middleware/authorization.go | 25 +++- re/mocks/repository.go | 112 ++++++++++++++--- re/postgres/repository.go | 112 +++++++++++++---- re/postgres/repository_test.go | 188 +++++++++++++++++++++++++++- re/rule.go | 4 +- re/service.go | 9 +- re/service_test.go | 52 ++++++-- reports/handler.go | 2 +- reports/middleware/authorization.go | 25 +++- reports/mocks/repository.go | 112 ++++++++++++++--- reports/postgres/repository.go | 112 +++++++++++++---- reports/postgres/repository_test.go | 128 ++++++++++++++++++- reports/reports.go | 4 +- reports/service.go | 9 +- reports/service_test.go | 48 ++++++- 16 files changed, 826 insertions(+), 120 deletions(-) diff --git a/re/handlers.go b/re/handlers.go index 851890feb..455b0e5f0 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -43,7 +43,7 @@ func (re *re) Handle(msg *messaging.Message) error { Scheduled: &scheduledFalse, } ctx := context.Background() - page, err := re.repo.ListRules(ctx, pm) + page, err := re.repo.ListAllRules(ctx, pm) if err != nil { return err } @@ -130,7 +130,7 @@ func (re *re) StartScheduler(ctx context.Context) error { ScheduledBefore: &due, } - page, err := re.repo.ListRules(ctx, pm) + page, err := re.repo.ListAllRules(ctx, pm) if err != nil { re.runInfo <- pkglog.RunInfo{ Level: slog.LevelError, diff --git a/re/middleware/authorization.go b/re/middleware/authorization.go index b957934f1..4834d41fb 100644 --- a/re/middleware/authorization.go +++ b/re/middleware/authorization.go @@ -11,6 +11,7 @@ import ( "github.com/absmach/supermq/pkg/authn" smqauthz "github.com/absmach/supermq/pkg/authz" "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/messaging" "github.com/absmach/supermq/pkg/permissions" "github.com/absmach/supermq/pkg/policies" @@ -90,8 +91,12 @@ func (am *authorizationMiddleware) UpdateRuleSchedule(ctx context.Context, sessi } func (am *authorizationMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) { - if err := am.authorize(ctx, operations.OpListRules, session, policies.DomainType, session.DomainID); err != nil { - return re.Page{}, errors.Wrap(errDomainViewRules, err) + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: + session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return re.Page{}, err } return am.svc.ListRules(ctx, session, pm) @@ -168,3 +173,19 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions return nil } + +func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error { + if session.Role != authn.SuperAdminRole { + return svcerr.ErrSuperAdminAction + } + if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{ + SubjectType: policies.UserType, + Subject: session.UserID, + Permission: policies.AdminPermission, + ObjectType: policies.PlatformType, + Object: policies.SuperMQObject, + }, nil); err != nil { + return err + } + return nil +} diff --git a/re/mocks/repository.go b/re/mocks/repository.go index 730351ab4..d26d79b90 100644 --- a/re/mocks/repository.go +++ b/re/mocks/repository.go @@ -178,6 +178,72 @@ func (_c *Repository_AddRule_Call) RunAndReturn(run func(ctx context.Context, r return _c } +// ListAllRules provides a mock function for the type Repository +func (_mock *Repository) ListAllRules(ctx context.Context, pm re.PageMeta) (re.Page, error) { + ret := _mock.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAllRules") + } + + var r0 re.Page + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) (re.Page, error)); ok { + return returnFunc(ctx, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) re.Page); ok { + r0 = returnFunc(ctx, pm) + } else { + r0 = ret.Get(0).(re.Page) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, re.PageMeta) error); ok { + r1 = returnFunc(ctx, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListAllRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllRules' +type Repository_ListAllRules_Call struct { + *mock.Call +} + +// ListAllRules is a helper method to define mock.On call +// - ctx context.Context +// - pm re.PageMeta +func (_e *Repository_Expecter) ListAllRules(ctx interface{}, pm interface{}) *Repository_ListAllRules_Call { + return &Repository_ListAllRules_Call{Call: _e.mock.On("ListAllRules", ctx, pm)} +} + +func (_c *Repository_ListAllRules_Call) Run(run func(ctx context.Context, pm re.PageMeta)) *Repository_ListAllRules_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 re.PageMeta + if args[1] != nil { + arg1 = args[1].(re.PageMeta) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_ListAllRules_Call) Return(page re.Page, err error) *Repository_ListAllRules_Call { + _c.Call.Return(page, err) + return _c +} + +func (_c *Repository_ListAllRules_Call) RunAndReturn(run func(ctx context.Context, pm re.PageMeta) (re.Page, error)) *Repository_ListAllRules_Call { + _c.Call.Return(run) + return _c +} + // ListEntityMembers provides a mock function for the type Repository func (_mock *Repository) ListEntityMembers(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) { ret := _mock.Called(ctx, entityID, pageQuery) @@ -250,68 +316,74 @@ func (_c *Repository_ListEntityMembers_Call) RunAndReturn(run func(ctx context.C return _c } -// ListRules provides a mock function for the type Repository -func (_mock *Repository) ListRules(ctx context.Context, pm re.PageMeta) (re.Page, error) { - ret := _mock.Called(ctx, pm) +// ListUserRules provides a mock function for the type Repository +func (_mock *Repository) ListUserRules(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error) { + ret := _mock.Called(ctx, userID, pm) if len(ret) == 0 { - panic("no return value specified for ListRules") + panic("no return value specified for ListUserRules") } var r0 re.Page var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) (re.Page, error)); ok { - return returnFunc(ctx, pm) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, re.PageMeta) (re.Page, error)); ok { + return returnFunc(ctx, userID, pm) } - if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) re.Page); ok { - r0 = returnFunc(ctx, pm) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, re.PageMeta) re.Page); ok { + r0 = returnFunc(ctx, userID, pm) } else { r0 = ret.Get(0).(re.Page) } - if returnFunc, ok := ret.Get(1).(func(context.Context, re.PageMeta) error); ok { - r1 = returnFunc(ctx, pm) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, re.PageMeta) error); ok { + r1 = returnFunc(ctx, userID, pm) } else { r1 = ret.Error(1) } return r0, r1 } -// Repository_ListRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListRules' -type Repository_ListRules_Call struct { +// Repository_ListUserRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRules' +type Repository_ListUserRules_Call struct { *mock.Call } -// ListRules is a helper method to define mock.On call +// ListUserRules is a helper method to define mock.On call // - ctx context.Context +// - userID string // - pm re.PageMeta -func (_e *Repository_Expecter) ListRules(ctx interface{}, pm interface{}) *Repository_ListRules_Call { - return &Repository_ListRules_Call{Call: _e.mock.On("ListRules", ctx, pm)} +func (_e *Repository_Expecter) ListUserRules(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserRules_Call { + return &Repository_ListUserRules_Call{Call: _e.mock.On("ListUserRules", ctx, userID, pm)} } -func (_c *Repository_ListRules_Call) Run(run func(ctx context.Context, pm re.PageMeta)) *Repository_ListRules_Call { +func (_c *Repository_ListUserRules_Call) Run(run func(ctx context.Context, userID string, pm re.PageMeta)) *Repository_ListUserRules_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 re.PageMeta + var arg1 string if args[1] != nil { - arg1 = args[1].(re.PageMeta) + arg1 = args[1].(string) + } + var arg2 re.PageMeta + if args[2] != nil { + arg2 = args[2].(re.PageMeta) } run( arg0, arg1, + arg2, ) }) return _c } -func (_c *Repository_ListRules_Call) Return(page re.Page, err error) *Repository_ListRules_Call { +func (_c *Repository_ListUserRules_Call) Return(page re.Page, err error) *Repository_ListUserRules_Call { _c.Call.Return(page, err) return _c } -func (_c *Repository_ListRules_Call) RunAndReturn(run func(ctx context.Context, pm re.PageMeta) (re.Page, error)) *Repository_ListRules_Call { +func (_c *Repository_ListUserRules_Call) RunAndReturn(run func(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error)) *Repository_ListUserRules_Call { _c.Call.Return(run) return _c } diff --git a/re/postgres/repository.go b/re/postgres/repository.go index a64b8ba67..1c8aaa34a 100644 --- a/re/postgres/repository.go +++ b/re/postgres/repository.go @@ -357,33 +357,10 @@ func (repo *PostgresRepository) RemoveRule(ctx context.Context, id string) error return nil } -func (repo *PostgresRepository) ListRules(ctx context.Context, pm re.PageMeta) (re.Page, error) { - pgData := "" - if pm.Limit != 0 { - pgData = "LIMIT :limit" - } - if pm.Offset != 0 { - pgData += " OFFSET :offset" - } +func (repo *PostgresRepository) ListAllRules(ctx context.Context, pm re.PageMeta) (re.Page, error) { pq := pageRulesQuery(pm) - - dir := api.DescDir - if pm.Dir == api.AscDir { - dir = api.AscDir - } - - orderClause := "" - - switch pm.Order { - case api.NameKey: - orderClause = fmt.Sprintf("ORDER BY name %s, id %s", dir, dir) - case api.CreatedAtOrder: - orderClause = fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir) - case api.UpdatedAtOrder: - orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir) - default: - orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir) - } + orderClause := rulesOrderClause(pm) + pgData := rulesPageData(pm) q := fmt.Sprintf(` SELECT id, name, domain_id, tags, input_channel, input_topic, logic_type, logic_value, outputs, @@ -425,6 +402,62 @@ func (repo *PostgresRepository) ListRules(ctx context.Context, pm re.PageMeta) ( return ret, nil } +func (repo *PostgresRepository) ListUserRules(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error) { + pm.UserID = userID + pq := pageRulesQuery(pm) + orderClause := rulesOrderClause(pm) + pgData := rulesPageData(pm) + + userJoin := ` + INNER JOIN rules_roles rr ON rr.entity_id = r.id + INNER JOIN rules_role_members rrm ON rrm.role_id = rr.id AND rrm.member_id = :user_id + ` + + innerQ := fmt.Sprintf(` + SELECT DISTINCT r.id, r.name, r.domain_id, r.tags, r.input_channel, r.input_topic, r.logic_type, r.logic_value, r.outputs, + r.start_datetime, r.time, r.recurring, r.recurring_period, r.created_at, r.created_by, r.updated_at, r.updated_by, r.status + FROM rules r + %s + %s + `, userJoin, pq) + + q := fmt.Sprintf(` + SELECT * FROM (%s) AS sub %s %s; + `, innerQ, orderClause, pgData) + + rows, err := repo.DB.NamedQueryContext(ctx, q, pm) + if err != nil { + return re.Page{}, err + } + defer rows.Close() + + var rules []re.Rule + for rows.Next() { + var r dbRule + if err := rows.StructScan(&r); err != nil { + return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + ret, err := dbToRule(r) + if err != nil { + return re.Page{}, err + } + rules = append(rules, ret) + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM (%s) AS count_sub;`, innerQ) + total, err := postgres.Total(ctx, repo.DB, cq, pm) + if err != nil { + return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return re.Page{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + Rules: rules, + }, nil +} + func (repo *PostgresRepository) UpdateRuleDue(ctx context.Context, id string, due time.Time) (re.Rule, error) { q := ` UPDATE rules @@ -460,6 +493,33 @@ func (repo *PostgresRepository) UpdateRuleDue(ctx context.Context, id string, du return rule, nil } +func rulesOrderClause(pm re.PageMeta) string { + dir := api.DescDir + if pm.Dir == api.AscDir { + dir = api.AscDir + } + + switch pm.Order { + case api.NameKey: + return fmt.Sprintf("ORDER BY name %s, id %s", dir, dir) + case api.CreatedAtOrder: + return fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir) + default: + return fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir) + } +} + +func rulesPageData(pm re.PageMeta) string { + pgData := "" + if pm.Limit != 0 { + pgData = "LIMIT :limit" + } + if pm.Offset != 0 { + pgData += " OFFSET :offset" + } + return pgData +} + func pageRulesQuery(pm re.PageMeta) string { var query []string if pm.InputChannel != "" { diff --git a/re/postgres/repository_test.go b/re/postgres/repository_test.go index 61c2cded4..62150f113 100644 --- a/re/postgres/repository_test.go +++ b/re/postgres/repository_test.go @@ -890,7 +890,7 @@ func TestListRules(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - page, err := repo.ListRules(context.Background(), tc.pm) + page, err := repo.ListAllRules(context.Background(), tc.pm) if tc.err != nil { assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) return @@ -935,6 +935,192 @@ func TestListRules(t *testing.T) { } } +func TestListUserRules(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + userID := generateUUID(t) + otherUserID := generateUUID(t) + channelID := generateUUID(t) + + // Create 10 rules; assign the first 4 to userID via a role. + var allRules []re.Rule + for i := range 10 { + r := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: channelID, + Logic: re.Script{Type: re.LuaType, Value: "return true"}, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), r) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + allRules = append(allRules, rule) + } + + // Assign userID to the first 4 rules via direct role INSERT. + for i := range 4 { + roleID := generateUUID(t) + _, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", allRules[i].ID) + assert.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err)) + _, err = db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, allRules[i].ID) + assert.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err)) + } + + cases := []struct { + desc string + userID string + pm re.PageMeta + count int + err error + }{ + { + desc: "list user rules returns only accessible rules", + userID: userID, + pm: re.PageMeta{ + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 4, + err: nil, + }, + { + desc: "list user rules with offset", + userID: userID, + pm: re.PageMeta{ + Offset: 2, + Limit: 100, + Status: re.AllStatus, + }, + count: 2, + err: nil, + }, + { + desc: "list user rules with limit", + userID: userID, + pm: re.PageMeta{ + Offset: 0, + Limit: 2, + Status: re.AllStatus, + }, + count: 2, + err: nil, + }, + { + desc: "list user rules with domain filter", + userID: userID, + pm: re.PageMeta{ + Domain: domainID, + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 4, + err: nil, + }, + { + desc: "list user rules with channel filter", + userID: userID, + pm: re.PageMeta{ + InputChannel: channelID, + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 4, + err: nil, + }, + { + desc: "list user rules with non-existing domain returns 0", + userID: userID, + pm: re.PageMeta{ + Domain: generateUUID(t), + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 0, + err: nil, + }, + { + desc: "list rules for user with no role assignments returns 0", + userID: otherUserID, + pm: re.PageMeta{ + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 0, + err: nil, + }, + { + desc: "list user rules ordered by name ascending", + userID: userID, + pm: re.PageMeta{ + Offset: 0, + Limit: 100, + Status: re.AllStatus, + Order: nameOrder, + Dir: ascDir, + }, + count: 4, + err: nil, + }, + { + desc: "list user rules ordered by created_at descending", + userID: userID, + pm: re.PageMeta{ + Offset: 0, + Limit: 100, + Status: re.AllStatus, + Order: createdAtOrder, + Dir: descDir, + }, + count: 4, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + page, err := repo.ListUserRules(context.Background(), tc.userID, tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + assert.Equal(t, tc.count, len(page.Rules), fmt.Sprintf("%s: expected %d rules, got %d", tc.desc, tc.count, len(page.Rules))) + if len(page.Rules) > 1 { + switch tc.pm.Order { + case nameOrder: + if tc.pm.Dir == ascDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].Name <= page.Rules[j].Name + }), "Expected names to be sorted ascending") + } + case createdAtOrder: + if tc.pm.Dir == descDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].CreatedAt.After(page.Rules[j].CreatedAt) + }), "Expected created_at to be sorted descending") + } + } + } + }) + } +} + func TestRemoveRule(t *testing.T) { t.Cleanup(func() { _, err := db.Exec("DELETE FROM rules") diff --git a/re/rule.go b/re/rule.go index 349ee21f6..bd4c7b2c7 100644 --- a/re/rule.go +++ b/re/rule.go @@ -172,6 +172,7 @@ type PageMeta struct { ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time ScheduledAfter *time.Time `json:"scheduled_after,omitempty" db:"scheduled_after"` // Filter rules scheduled after this time Recurring *schedule.Recurring `json:"recurring,omitempty" db:"recurring"` // Filter by recurring type + UserID string `json:"user_id,omitempty" db:"user_id"` } // EventEncode converts a PageMeta struct to map[string]any. @@ -250,7 +251,8 @@ type Repository interface { UpdateRuleSchedule(ctx context.Context, r Rule) (Rule, error) RemoveRule(ctx context.Context, id string) error UpdateRuleStatus(ctx context.Context, r Rule) (Rule, error) - ListRules(ctx context.Context, pm PageMeta) (Page, error) + ListAllRules(ctx context.Context, pm PageMeta) (Page, error) + ListUserRules(ctx context.Context, userID string, pm PageMeta) (Page, error) UpdateRuleDue(ctx context.Context, id string, due time.Time) (Rule, error) roles.Repository } diff --git a/re/service.go b/re/service.go index 93ad12395..9a52cc427 100644 --- a/re/service.go +++ b/re/service.go @@ -175,7 +175,14 @@ func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r R func (re *re) ListRules(ctx context.Context, session authn.Session, pm PageMeta) (Page, error) { pm.Domain = session.DomainID - page, err := re.repo.ListRules(ctx, pm) + if session.SuperAdmin { + page, err := re.repo.ListAllRules(ctx, pm) + if err != nil { + return Page{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + return page, nil + } + page, err := re.repo.ListUserRules(ctx, session.UserID, pm) if err != nil { return Page{}, errors.Wrap(svcerr.ErrViewEntity, err) } diff --git a/re/service_test.go b/re/service_test.go index 09081435d..e9694bb3a 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -868,11 +868,12 @@ func TestListRules(t *testing.T) { } cases := []struct { - desc string - session authn.Session - pageMeta re.PageMeta - res re.Page - err error + desc string + session authn.Session + pageMeta re.PageMeta + res re.Page + err error + superAdmin bool }{ { desc: "list rules successfully", @@ -948,11 +949,44 @@ func TestListRules(t *testing.T) { pageMeta: re.PageMeta{}, err: svcerr.ErrViewEntity, }, + { + desc: "list rules as super admin successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + SuperAdmin: true, + }, + pageMeta: re.PageMeta{}, + res: re.Page{ + Total: uint64(numRules), + Offset: 0, + Limit: 10, + Rules: rules[0:10], + }, + superAdmin: true, + err: nil, + }, + { + desc: "list rules as super admin with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + SuperAdmin: true, + }, + pageMeta: re.PageMeta{}, + superAdmin: true, + err: svcerr.ErrViewEntity, + }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(tc.res, tc.err) + var repoCall *mock.Call + if tc.superAdmin { + repoCall = repo.On("ListAllRules", mock.Anything, mock.Anything).Return(tc.res, tc.err) + } else { + repoCall = repo.On("ListUserRules", mock.Anything, mock.Anything, mock.Anything).Return(tc.res, tc.err) + } res, err := svc.ListRules(context.Background(), tc.session, tc.pageMeta) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) @@ -1778,7 +1812,7 @@ func TestHandle(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { var err error - repoCall := repo.On("ListRules", mock.Anything, re.PageMeta{Domain: tc.message.Domain, InputChannel: tc.message.Channel, Scheduled: &scheduled}).Return(tc.page, tc.listErr).Run(func(args mock.Arguments) { + repoCall := repo.On("ListAllRules", mock.Anything, re.PageMeta{Domain: tc.message.Domain, InputChannel: tc.message.Channel, Scheduled: &scheduled}).Return(tc.page, tc.listErr).Run(func(args mock.Arguments) { if tc.listErr != nil { err = tc.listErr } @@ -1854,7 +1888,7 @@ func TestStartScheduler(t *testing.T) { for _, tc := range ctxCases { t.Run(tc.desc, func(t *testing.T) { - repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(tc.page, tc.listErr) + repoCall := repo.On("ListAllRules", mock.Anything, mock.Anything).Return(tc.page, tc.listErr) tickChan := make(chan time.Time) tickCall := ticker.On("Tick").Return((<-chan time.Time)(tickChan)) tickCall1 := ticker.On("Stop").Return() @@ -1989,7 +2023,7 @@ func TestStartScheduler(t *testing.T) { Total: uint64(len(tc.rules)), } - repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(page, tc.listErr) + repoCall := repo.On("ListAllRules", mock.Anything, mock.Anything).Return(page, tc.listErr) repoCall2 := repo.On("UpdateRuleDue", mock.Anything, mock.Anything, mock.Anything).Return(re.Rule{}, tc.updateDueErr) tickChan := make(chan time.Time, 1) tickCall := ticker.On("Tick").Return((<-chan time.Time)(tickChan)) diff --git a/reports/handler.go b/reports/handler.go index 64a54f3a1..e97a5042e 100644 --- a/reports/handler.go +++ b/reports/handler.go @@ -27,7 +27,7 @@ func (r *report) StartScheduler(ctx context.Context) error { ScheduledBefore: &due, } - reportConfigs, err := r.repo.ListReportsConfig(ctx, pm) + reportConfigs, err := r.repo.ListAllReportsConfig(ctx, pm) if err != nil { r.runInfo <- pkglog.RunInfo{ Level: slog.LevelError, diff --git a/reports/middleware/authorization.go b/reports/middleware/authorization.go index 7be992af5..fd1434b72 100644 --- a/reports/middleware/authorization.go +++ b/reports/middleware/authorization.go @@ -11,6 +11,7 @@ import ( "github.com/absmach/supermq/pkg/authn" smqauthz "github.com/absmach/supermq/pkg/authz" "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/permissions" "github.com/absmach/supermq/pkg/policies" rolemgr "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" @@ -93,8 +94,12 @@ func (am *authorizationMiddleware) RemoveReportConfig(ctx context.Context, sessi } func (am *authorizationMiddleware) ListReportsConfig(ctx context.Context, session authn.Session, pm reports.PageMeta) (reports.ReportConfigPage, error) { - if err := am.authorize(ctx, operations.OpListReportsConfig, session, policies.DomainType, session.DomainID); err != nil { - return reports.ReportConfigPage{}, errors.Wrap(errDomainViewConfigs, err) + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: + session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return reports.ReportConfigPage{}, err } return am.svc.ListReportsConfig(ctx, session, pm) @@ -187,3 +192,19 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions return nil } + +func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error { + if session.Role != authn.SuperAdminRole { + return svcerr.ErrSuperAdminAction + } + if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{ + SubjectType: policies.UserType, + Subject: session.UserID, + Permission: policies.AdminPermission, + ObjectType: policies.PlatformType, + Object: policies.SuperMQObject, + }, nil); err != nil { + return err + } + return nil +} diff --git a/reports/mocks/repository.go b/reports/mocks/repository.go index 8d9e1905d..2ed50b3e4 100644 --- a/reports/mocks/repository.go +++ b/reports/mocks/repository.go @@ -241,6 +241,72 @@ func (_c *Repository_DeleteReportTemplate_Call) RunAndReturn(run func(ctx contex return _c } +// ListAllReportsConfig provides a mock function for the type Repository +func (_mock *Repository) ListAllReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) { + ret := _mock.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAllReportsConfig") + } + + var r0 reports.ReportConfigPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) (reports.ReportConfigPage, error)); ok { + return returnFunc(ctx, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) reports.ReportConfigPage); ok { + r0 = returnFunc(ctx, pm) + } else { + r0 = ret.Get(0).(reports.ReportConfigPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, reports.PageMeta) error); ok { + r1 = returnFunc(ctx, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListAllReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllReportsConfig' +type Repository_ListAllReportsConfig_Call struct { + *mock.Call +} + +// ListAllReportsConfig is a helper method to define mock.On call +// - ctx context.Context +// - pm reports.PageMeta +func (_e *Repository_Expecter) ListAllReportsConfig(ctx interface{}, pm interface{}) *Repository_ListAllReportsConfig_Call { + return &Repository_ListAllReportsConfig_Call{Call: _e.mock.On("ListAllReportsConfig", ctx, pm)} +} + +func (_c *Repository_ListAllReportsConfig_Call) Run(run func(ctx context.Context, pm reports.PageMeta)) *Repository_ListAllReportsConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 reports.PageMeta + if args[1] != nil { + arg1 = args[1].(reports.PageMeta) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_ListAllReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListAllReportsConfig_Call { + _c.Call.Return(reportConfigPage, err) + return _c +} + +func (_c *Repository_ListAllReportsConfig_Call) RunAndReturn(run func(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListAllReportsConfig_Call { + _c.Call.Return(run) + return _c +} + // ListEntityMembers provides a mock function for the type Repository func (_mock *Repository) ListEntityMembers(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) { ret := _mock.Called(ctx, entityID, pageQuery) @@ -313,68 +379,74 @@ func (_c *Repository_ListEntityMembers_Call) RunAndReturn(run func(ctx context.C return _c } -// ListReportsConfig provides a mock function for the type Repository -func (_mock *Repository) ListReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) { - ret := _mock.Called(ctx, pm) +// ListUserReportsConfig provides a mock function for the type Repository +func (_mock *Repository) ListUserReportsConfig(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error) { + ret := _mock.Called(ctx, userID, pm) if len(ret) == 0 { - panic("no return value specified for ListReportsConfig") + panic("no return value specified for ListUserReportsConfig") } var r0 reports.ReportConfigPage var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) (reports.ReportConfigPage, error)); ok { - return returnFunc(ctx, pm) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, reports.PageMeta) (reports.ReportConfigPage, error)); ok { + return returnFunc(ctx, userID, pm) } - if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) reports.ReportConfigPage); ok { - r0 = returnFunc(ctx, pm) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, reports.PageMeta) reports.ReportConfigPage); ok { + r0 = returnFunc(ctx, userID, pm) } else { r0 = ret.Get(0).(reports.ReportConfigPage) } - if returnFunc, ok := ret.Get(1).(func(context.Context, reports.PageMeta) error); ok { - r1 = returnFunc(ctx, pm) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, reports.PageMeta) error); ok { + r1 = returnFunc(ctx, userID, pm) } else { r1 = ret.Error(1) } return r0, r1 } -// Repository_ListReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListReportsConfig' -type Repository_ListReportsConfig_Call struct { +// Repository_ListUserReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserReportsConfig' +type Repository_ListUserReportsConfig_Call struct { *mock.Call } -// ListReportsConfig is a helper method to define mock.On call +// ListUserReportsConfig is a helper method to define mock.On call // - ctx context.Context +// - userID string // - pm reports.PageMeta -func (_e *Repository_Expecter) ListReportsConfig(ctx interface{}, pm interface{}) *Repository_ListReportsConfig_Call { - return &Repository_ListReportsConfig_Call{Call: _e.mock.On("ListReportsConfig", ctx, pm)} +func (_e *Repository_Expecter) ListUserReportsConfig(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserReportsConfig_Call { + return &Repository_ListUserReportsConfig_Call{Call: _e.mock.On("ListUserReportsConfig", ctx, userID, pm)} } -func (_c *Repository_ListReportsConfig_Call) Run(run func(ctx context.Context, pm reports.PageMeta)) *Repository_ListReportsConfig_Call { +func (_c *Repository_ListUserReportsConfig_Call) Run(run func(ctx context.Context, userID string, pm reports.PageMeta)) *Repository_ListUserReportsConfig_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 reports.PageMeta + var arg1 string if args[1] != nil { - arg1 = args[1].(reports.PageMeta) + arg1 = args[1].(string) + } + var arg2 reports.PageMeta + if args[2] != nil { + arg2 = args[2].(reports.PageMeta) } run( arg0, arg1, + arg2, ) }) return _c } -func (_c *Repository_ListReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListReportsConfig_Call { +func (_c *Repository_ListUserReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListUserReportsConfig_Call { _c.Call.Return(reportConfigPage, err) return _c } -func (_c *Repository_ListReportsConfig_Call) RunAndReturn(run func(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListReportsConfig_Call { +func (_c *Repository_ListUserReportsConfig_Call) RunAndReturn(run func(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListUserReportsConfig_Call { _c.Call.Return(run) return _c } diff --git a/reports/postgres/repository.go b/reports/postgres/repository.go index 0a5b2fc2c..6b41f7616 100644 --- a/reports/postgres/repository.go +++ b/reports/postgres/repository.go @@ -405,39 +405,16 @@ func (repo *PostgresRepository) RemoveReportConfig(ctx context.Context, id strin return nil } -func (repo *PostgresRepository) ListReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) { +func (repo *PostgresRepository) ListAllReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) { listReportsQuery := ` SELECT id, name, description, domain_id, metrics, email, config, start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status FROM report_config rc %s %s %s; ` - pgData := "" - if pm.Limit != 0 { - pgData = "LIMIT :limit" - } - if pm.Offset != 0 { - pgData += " OFFSET :offset" - } pq := pageReportQuery(pm) - - dir := api.DescDir - if pm.Dir == api.AscDir { - dir = api.AscDir - } - - orderClause := "" - - switch pm.Order { - case api.NameKey: - orderClause = fmt.Sprintf("ORDER BY name %s, id %s", dir, dir) - case api.CreatedAtOrder: - orderClause = fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir) - case api.UpdatedAtOrder: - orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir) - default: - orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir) - } + orderClause := reportsOrderClause(pm) + pgData := reportsPageData(pm) q := fmt.Sprintf(listReportsQuery, pq, orderClause, pgData) rows, err := repo.DB.NamedQueryContext(ctx, q, pm) @@ -474,6 +451,63 @@ func (repo *PostgresRepository) ListReportsConfig(ctx context.Context, pm report return ret, nil } +func (repo *PostgresRepository) ListUserReportsConfig(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error) { + pq := pageReportQuery(pm) + orderClause := reportsOrderClause(pm) + pgData := reportsPageData(pm) + + pm.UserID = userID + userJoin := ` + INNER JOIN reports_roles rr ON rr.entity_id = rc.id + INNER JOIN reports_role_members rrm ON rrm.role_id = rr.id AND rrm.member_id = :user_id + ` + + whereClause := pq + + innerQ := fmt.Sprintf(` + SELECT DISTINCT rc.id, rc.name, rc.description, rc.domain_id, rc.metrics, rc.email, rc.config, + rc.start_datetime, rc.due, rc.recurring, rc.recurring_period, rc.created_at, rc.created_by, rc.updated_at, rc.updated_by, rc.status + FROM report_config rc + %s + %s + `, userJoin, whereClause) + + q := fmt.Sprintf(` + SELECT * FROM (%s) AS sub %s %s; + `, innerQ, orderClause, pgData) + + rows, err := repo.DB.NamedQueryContext(ctx, q, pm) + if err != nil { + return reports.ReportConfigPage{}, err + } + defer rows.Close() + + cfgs := []reports.ReportConfig{} + for rows.Next() { + var r dbReport + if err := rows.StructScan(&r); err != nil { + return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + rpt, err := dbToReport(r) + if err != nil { + return reports.ReportConfigPage{}, err + } + cfgs = append(cfgs, rpt) + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM (%s) AS count_sub;`, innerQ) + total, err := postgres.Total(ctx, repo.DB, cq, pm) + if err != nil { + return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + pm.Total = total + + return reports.ReportConfigPage{ + PageMeta: pm, + ReportConfigs: cfgs, + }, nil +} + func (repo *PostgresRepository) UpdateReportDue(ctx context.Context, id string, due time.Time) (reports.ReportConfig, error) { q := ` UPDATE report_config @@ -576,6 +610,32 @@ func (repo *PostgresRepository) DeleteReportTemplate(ctx context.Context, domain return nil } +func reportsOrderClause(pm reports.PageMeta) string { + dir := api.DescDir + if pm.Dir == api.AscDir { + dir = api.AscDir + } + switch pm.Order { + case api.NameKey: + return fmt.Sprintf("ORDER BY name %s, id %s", dir, dir) + case api.CreatedAtOrder: + return fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir) + default: + return fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir) + } +} + +func reportsPageData(pm reports.PageMeta) string { + pgData := "" + if pm.Limit != 0 { + pgData = "LIMIT :limit" + } + if pm.Offset != 0 { + pgData += " OFFSET :offset" + } + return pgData +} + func pageReportQuery(pm reports.PageMeta) string { var query []string if pm.Status != reports.AllStatus { diff --git a/reports/postgres/repository_test.go b/reports/postgres/repository_test.go index 5ea2a2166..84a16f303 100644 --- a/reports/postgres/repository_test.go +++ b/reports/postgres/repository_test.go @@ -472,7 +472,7 @@ func TestListReportsConfig(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - page, err := repo.ListReportsConfig(context.Background(), tc.pageMeta) + page, err := repo.ListAllReportsConfig(context.Background(), tc.pageMeta) if tc.err != nil { assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) return @@ -483,6 +483,132 @@ func TestListReportsConfig(t *testing.T) { } } +func TestListUserReportsConfig(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + userID := generateUUID(t) + otherUserID := generateUUID(t) + + num := 10 + var allCfgs []reports.ReportConfig + for i := range num { + cfg := reports.ReportConfig{ + ID: generateUUID(t), + Name: fmt.Sprintf("Report-%d", i), + DomainID: domainID, + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute), + UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute), + Metrics: []reports.ReqMetric{}, + } + cfg, err := repo.AddReportConfig(context.Background(), cfg) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + allCfgs = append(allCfgs, cfg) + } + + // Assign userID to the first 5 report configs via direct role INSERT. + for i := range 5 { + roleID := generateUUID(t) + _, err := db.Exec(`INSERT INTO reports_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", allCfgs[i].ID) + require.Nil(t, err, fmt.Sprintf("insert reports_roles unexpected error: %s", err)) + _, err = db.Exec(`INSERT INTO reports_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, allCfgs[i].ID) + require.Nil(t, err, fmt.Sprintf("insert reports_role_members unexpected error: %s", err)) + } + + cases := []struct { + desc string + userID string + pageMeta reports.PageMeta + size int + err error + }{ + { + desc: "list user reports returns only accessible reports", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 100, + Offset: 0, + }, + size: 5, + err: nil, + }, + { + desc: "list user reports with limit", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 3, + Offset: 0, + }, + size: 3, + err: nil, + }, + { + desc: "list user reports with offset", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 100, + Offset: 3, + }, + size: 2, + err: nil, + }, + { + desc: "list user reports with enabled status filter", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 100, + Status: reports.EnabledStatus, + }, + size: 5, + err: nil, + }, + { + desc: "list reports for user with no role assignments returns 0", + userID: otherUserID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 100, + Offset: 0, + }, + size: 0, + err: nil, + }, + { + desc: "list user reports with non-existing domain returns 0", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: generateUUID(t), + Limit: 100, + Offset: 0, + }, + size: 0, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + page, err := repo.ListUserReportsConfig(context.Background(), tc.userID, tc.pageMeta) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, tc.size, len(page.ReportConfigs), fmt.Sprintf("%s: expected %d reports, got %d", tc.desc, tc.size, len(page.ReportConfigs))) + }) + } +} + func TestUpdateReportSchedule(t *testing.T) { t.Cleanup(func() { _, err := db.Exec("DELETE FROM report_config") diff --git a/reports/reports.go b/reports/reports.go index 495f55365..d1cab626b 100644 --- a/reports/reports.go +++ b/reports/reports.go @@ -395,6 +395,7 @@ type PageMeta struct { Domain string `json:"domain_id,omitempty" db:"domain_id"` ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time ScheduledAfter *time.Time `json:"scheduled_after,omitempty" db:"scheduled_after"` // Filter rules scheduled after this time + UserID string `json:"user_id,omitempty" db:"user_id"` } type Repository interface { @@ -405,7 +406,8 @@ type Repository interface { UpdateReportSchedule(ctx context.Context, cfg ReportConfig) (ReportConfig, error) RemoveReportConfig(ctx context.Context, id string) error UpdateReportConfigStatus(ctx context.Context, cfg ReportConfig) (ReportConfig, error) - ListReportsConfig(ctx context.Context, pm PageMeta) (ReportConfigPage, error) + ListAllReportsConfig(ctx context.Context, pm PageMeta) (ReportConfigPage, error) + ListUserReportsConfig(ctx context.Context, userID string, pm PageMeta) (ReportConfigPage, error) UpdateReportDue(ctx context.Context, id string, due time.Time) (ReportConfig, error) UpdateReportTemplate(ctx context.Context, domainID, reportID string, template ReportTemplate) error diff --git a/reports/service.go b/reports/service.go index f1667c277..c8db494af 100644 --- a/reports/service.go +++ b/reports/service.go @@ -159,7 +159,14 @@ func (r *report) RemoveReportConfig(ctx context.Context, session authn.Session, func (r *report) ListReportsConfig(ctx context.Context, session authn.Session, pm PageMeta) (ReportConfigPage, error) { pm.Domain = session.DomainID - page, err := r.repo.ListReportsConfig(ctx, pm) + if session.SuperAdmin { + page, err := r.repo.ListAllReportsConfig(ctx, pm) + if err != nil { + return ReportConfigPage{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + return page, nil + } + page, err := r.repo.ListUserReportsConfig(ctx, session.UserID, pm) if err != nil { return ReportConfigPage{}, errors.Wrap(svcerr.ErrViewEntity, err) } diff --git a/reports/service_test.go b/reports/service_test.go index 6968a235a..d49619a92 100644 --- a/reports/service_test.go +++ b/reports/service_test.go @@ -328,11 +328,12 @@ func TestListReportsConfig(t *testing.T) { } cases := []struct { - desc string - session authn.Session - pageMeta reports.PageMeta - res reports.ReportConfigPage - err error + desc string + session authn.Session + pageMeta reports.PageMeta + res reports.ReportConfigPage + err error + superAdmin bool }{ { desc: "list report configs successfully", @@ -399,11 +400,46 @@ func TestListReportsConfig(t *testing.T) { pageMeta: reports.PageMeta{}, err: svcerr.ErrViewEntity, }, + { + desc: "list report configs as super admin successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + SuperAdmin: true, + }, + pageMeta: reports.PageMeta{}, + res: reports.ReportConfigPage{ + PageMeta: reports.PageMeta{ + Total: uint64(numConfigs), + Offset: 0, + Limit: 10, + }, + ReportConfigs: configs[0:10], + }, + superAdmin: true, + err: nil, + }, + { + desc: "list report configs as super admin with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + SuperAdmin: true, + }, + pageMeta: reports.PageMeta{}, + superAdmin: true, + err: svcerr.ErrViewEntity, + }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repoCall := repo.On("ListReportsConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err) + var repoCall *mock.Call + if tc.superAdmin { + repoCall = repo.On("ListAllReportsConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err) + } else { + repoCall = repo.On("ListUserReportsConfig", mock.Anything, mock.Anything, mock.Anything).Return(tc.res, tc.err) + } res, err := svc.ListReportsConfig(context.Background(), tc.session, tc.pageMeta) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil {