mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 04:10:34 +00:00
NOISSUE - Refactor listing for rules and reports (#433)
* add access control to rules engine Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * update authorization method Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * revert code Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * initial implementation Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * remove domain from method Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix failing linter Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix userid parameter Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * update checksuperadmin method Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * revert changes Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * address comments Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> --------- Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
+2
-2
@@ -43,7 +43,7 @@ func (re *re) Handle(msg *messaging.Message) error {
|
||||
Scheduled: &scheduledFalse,
|
||||
}
|
||||
ctx := context.Background()
|
||||
page, err := re.repo.ListRules(ctx, pm)
|
||||
page, err := re.repo.ListAllRules(ctx, pm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -130,7 +130,7 @@ func (re *re) StartScheduler(ctx context.Context) error {
|
||||
ScheduledBefore: &due,
|
||||
}
|
||||
|
||||
page, err := re.repo.ListRules(ctx, pm)
|
||||
page, err := re.repo.ListAllRules(ctx, pm)
|
||||
if err != nil {
|
||||
re.runInfo <- pkglog.RunInfo{
|
||||
Level: slog.LevelError,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
"github.com/absmach/supermq/pkg/permissions"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
@@ -90,8 +91,12 @@ func (am *authorizationMiddleware) UpdateRuleSchedule(ctx context.Context, sessi
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) {
|
||||
if err := am.authorize(ctx, operations.OpListRules, session, policies.DomainType, session.DomainID); err != nil {
|
||||
return re.Page{}, errors.Wrap(errDomainViewRules, err)
|
||||
switch err := am.checkSuperAdmin(ctx, session); {
|
||||
case err == nil:
|
||||
session.SuperAdmin = true
|
||||
case errors.Contains(err, svcerr.ErrSuperAdminAction):
|
||||
default:
|
||||
return re.Page{}, err
|
||||
}
|
||||
|
||||
return am.svc.ListRules(ctx, session, pm)
|
||||
@@ -168,3 +173,19 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error {
|
||||
if session.Role != authn.SuperAdminRole {
|
||||
return svcerr.ErrSuperAdminAction
|
||||
}
|
||||
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.UserID,
|
||||
Permission: policies.AdminPermission,
|
||||
ObjectType: policies.PlatformType,
|
||||
Object: policies.SuperMQObject,
|
||||
}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+92
-20
@@ -178,6 +178,72 @@ func (_c *Repository_AddRule_Call) RunAndReturn(run func(ctx context.Context, r
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListAllRules provides a mock function for the type Repository
|
||||
func (_mock *Repository) ListAllRules(ctx context.Context, pm re.PageMeta) (re.Page, error) {
|
||||
ret := _mock.Called(ctx, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListAllRules")
|
||||
}
|
||||
|
||||
var r0 re.Page
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) (re.Page, error)); ok {
|
||||
return returnFunc(ctx, pm)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) re.Page); ok {
|
||||
r0 = returnFunc(ctx, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(re.Page)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, re.PageMeta) error); ok {
|
||||
r1 = returnFunc(ctx, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Repository_ListAllRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllRules'
|
||||
type Repository_ListAllRules_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListAllRules is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - pm re.PageMeta
|
||||
func (_e *Repository_Expecter) ListAllRules(ctx interface{}, pm interface{}) *Repository_ListAllRules_Call {
|
||||
return &Repository_ListAllRules_Call{Call: _e.mock.On("ListAllRules", ctx, pm)}
|
||||
}
|
||||
|
||||
func (_c *Repository_ListAllRules_Call) Run(run func(ctx context.Context, pm re.PageMeta)) *Repository_ListAllRules_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 re.PageMeta
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(re.PageMeta)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListAllRules_Call) Return(page re.Page, err error) *Repository_ListAllRules_Call {
|
||||
_c.Call.Return(page, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListAllRules_Call) RunAndReturn(run func(ctx context.Context, pm re.PageMeta) (re.Page, error)) *Repository_ListAllRules_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListEntityMembers provides a mock function for the type Repository
|
||||
func (_mock *Repository) ListEntityMembers(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) {
|
||||
ret := _mock.Called(ctx, entityID, pageQuery)
|
||||
@@ -250,68 +316,74 @@ func (_c *Repository_ListEntityMembers_Call) RunAndReturn(run func(ctx context.C
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListRules provides a mock function for the type Repository
|
||||
func (_mock *Repository) ListRules(ctx context.Context, pm re.PageMeta) (re.Page, error) {
|
||||
ret := _mock.Called(ctx, pm)
|
||||
// ListUserRules provides a mock function for the type Repository
|
||||
func (_mock *Repository) ListUserRules(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error) {
|
||||
ret := _mock.Called(ctx, userID, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListRules")
|
||||
panic("no return value specified for ListUserRules")
|
||||
}
|
||||
|
||||
var r0 re.Page
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) (re.Page, error)); ok {
|
||||
return returnFunc(ctx, pm)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, re.PageMeta) (re.Page, error)); ok {
|
||||
return returnFunc(ctx, userID, pm)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) re.Page); ok {
|
||||
r0 = returnFunc(ctx, pm)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, re.PageMeta) re.Page); ok {
|
||||
r0 = returnFunc(ctx, userID, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(re.Page)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, re.PageMeta) error); ok {
|
||||
r1 = returnFunc(ctx, pm)
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, re.PageMeta) error); ok {
|
||||
r1 = returnFunc(ctx, userID, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Repository_ListRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListRules'
|
||||
type Repository_ListRules_Call struct {
|
||||
// Repository_ListUserRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRules'
|
||||
type Repository_ListUserRules_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListRules is a helper method to define mock.On call
|
||||
// ListUserRules is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID string
|
||||
// - pm re.PageMeta
|
||||
func (_e *Repository_Expecter) ListRules(ctx interface{}, pm interface{}) *Repository_ListRules_Call {
|
||||
return &Repository_ListRules_Call{Call: _e.mock.On("ListRules", ctx, pm)}
|
||||
func (_e *Repository_Expecter) ListUserRules(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserRules_Call {
|
||||
return &Repository_ListUserRules_Call{Call: _e.mock.On("ListUserRules", ctx, userID, pm)}
|
||||
}
|
||||
|
||||
func (_c *Repository_ListRules_Call) Run(run func(ctx context.Context, pm re.PageMeta)) *Repository_ListRules_Call {
|
||||
func (_c *Repository_ListUserRules_Call) Run(run func(ctx context.Context, userID string, pm re.PageMeta)) *Repository_ListUserRules_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 re.PageMeta
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(re.PageMeta)
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 re.PageMeta
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(re.PageMeta)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListRules_Call) Return(page re.Page, err error) *Repository_ListRules_Call {
|
||||
func (_c *Repository_ListUserRules_Call) Return(page re.Page, err error) *Repository_ListUserRules_Call {
|
||||
_c.Call.Return(page, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListRules_Call) RunAndReturn(run func(ctx context.Context, pm re.PageMeta) (re.Page, error)) *Repository_ListRules_Call {
|
||||
func (_c *Repository_ListUserRules_Call) RunAndReturn(run func(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error)) *Repository_ListUserRules_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
+86
-26
@@ -357,33 +357,10 @@ func (repo *PostgresRepository) RemoveRule(ctx context.Context, id string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (repo *PostgresRepository) ListRules(ctx context.Context, pm re.PageMeta) (re.Page, error) {
|
||||
pgData := ""
|
||||
if pm.Limit != 0 {
|
||||
pgData = "LIMIT :limit"
|
||||
}
|
||||
if pm.Offset != 0 {
|
||||
pgData += " OFFSET :offset"
|
||||
}
|
||||
func (repo *PostgresRepository) ListAllRules(ctx context.Context, pm re.PageMeta) (re.Page, error) {
|
||||
pq := pageRulesQuery(pm)
|
||||
|
||||
dir := api.DescDir
|
||||
if pm.Dir == api.AscDir {
|
||||
dir = api.AscDir
|
||||
}
|
||||
|
||||
orderClause := ""
|
||||
|
||||
switch pm.Order {
|
||||
case api.NameKey:
|
||||
orderClause = fmt.Sprintf("ORDER BY name %s, id %s", dir, dir)
|
||||
case api.CreatedAtOrder:
|
||||
orderClause = fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir)
|
||||
case api.UpdatedAtOrder:
|
||||
orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
|
||||
default:
|
||||
orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
|
||||
}
|
||||
orderClause := rulesOrderClause(pm)
|
||||
pgData := rulesPageData(pm)
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
SELECT id, name, domain_id, tags, input_channel, input_topic, logic_type, logic_value, outputs,
|
||||
@@ -425,6 +402,62 @@ func (repo *PostgresRepository) ListRules(ctx context.Context, pm re.PageMeta) (
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (repo *PostgresRepository) ListUserRules(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error) {
|
||||
pm.UserID = userID
|
||||
pq := pageRulesQuery(pm)
|
||||
orderClause := rulesOrderClause(pm)
|
||||
pgData := rulesPageData(pm)
|
||||
|
||||
userJoin := `
|
||||
INNER JOIN rules_roles rr ON rr.entity_id = r.id
|
||||
INNER JOIN rules_role_members rrm ON rrm.role_id = rr.id AND rrm.member_id = :user_id
|
||||
`
|
||||
|
||||
innerQ := fmt.Sprintf(`
|
||||
SELECT DISTINCT r.id, r.name, r.domain_id, r.tags, r.input_channel, r.input_topic, r.logic_type, r.logic_value, r.outputs,
|
||||
r.start_datetime, r.time, r.recurring, r.recurring_period, r.created_at, r.created_by, r.updated_at, r.updated_by, r.status
|
||||
FROM rules r
|
||||
%s
|
||||
%s
|
||||
`, userJoin, pq)
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
SELECT * FROM (%s) AS sub %s %s;
|
||||
`, innerQ, orderClause, pgData)
|
||||
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, pm)
|
||||
if err != nil {
|
||||
return re.Page{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var rules []re.Rule
|
||||
for rows.Next() {
|
||||
var r dbRule
|
||||
if err := rows.StructScan(&r); err != nil {
|
||||
return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
ret, err := dbToRule(r)
|
||||
if err != nil {
|
||||
return re.Page{}, err
|
||||
}
|
||||
rules = append(rules, ret)
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM (%s) AS count_sub;`, innerQ)
|
||||
total, err := postgres.Total(ctx, repo.DB, cq, pm)
|
||||
if err != nil {
|
||||
return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return re.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
Rules: rules,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (repo *PostgresRepository) UpdateRuleDue(ctx context.Context, id string, due time.Time) (re.Rule, error) {
|
||||
q := `
|
||||
UPDATE rules
|
||||
@@ -460,6 +493,33 @@ func (repo *PostgresRepository) UpdateRuleDue(ctx context.Context, id string, du
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func rulesOrderClause(pm re.PageMeta) string {
|
||||
dir := api.DescDir
|
||||
if pm.Dir == api.AscDir {
|
||||
dir = api.AscDir
|
||||
}
|
||||
|
||||
switch pm.Order {
|
||||
case api.NameKey:
|
||||
return fmt.Sprintf("ORDER BY name %s, id %s", dir, dir)
|
||||
case api.CreatedAtOrder:
|
||||
return fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir)
|
||||
default:
|
||||
return fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
|
||||
}
|
||||
}
|
||||
|
||||
func rulesPageData(pm re.PageMeta) string {
|
||||
pgData := ""
|
||||
if pm.Limit != 0 {
|
||||
pgData = "LIMIT :limit"
|
||||
}
|
||||
if pm.Offset != 0 {
|
||||
pgData += " OFFSET :offset"
|
||||
}
|
||||
return pgData
|
||||
}
|
||||
|
||||
func pageRulesQuery(pm re.PageMeta) string {
|
||||
var query []string
|
||||
if pm.InputChannel != "" {
|
||||
|
||||
@@ -890,7 +890,7 @@ func TestListRules(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
page, err := repo.ListRules(context.Background(), tc.pm)
|
||||
page, err := repo.ListAllRules(context.Background(), tc.pm)
|
||||
if tc.err != nil {
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
return
|
||||
@@ -935,6 +935,192 @@ func TestListRules(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserRules(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM rules")
|
||||
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
|
||||
})
|
||||
|
||||
repo := postgres.NewRepository(database)
|
||||
|
||||
domainID := generateUUID(t)
|
||||
userID := generateUUID(t)
|
||||
otherUserID := generateUUID(t)
|
||||
channelID := generateUUID(t)
|
||||
|
||||
// Create 10 rules; assign the first 4 to userID via a role.
|
||||
var allRules []re.Rule
|
||||
for i := range 10 {
|
||||
r := re.Rule{
|
||||
ID: generateUUID(t),
|
||||
Name: namegen.Generate(),
|
||||
DomainID: domainID,
|
||||
InputChannel: channelID,
|
||||
Logic: re.Script{Type: re.LuaType, Value: "return true"},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond),
|
||||
CreatedBy: generateUUID(t),
|
||||
UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond),
|
||||
UpdatedBy: generateUUID(t),
|
||||
}
|
||||
rule, err := repo.AddRule(context.Background(), r)
|
||||
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
|
||||
allRules = append(allRules, rule)
|
||||
}
|
||||
|
||||
// Assign userID to the first 4 rules via direct role INSERT.
|
||||
for i := range 4 {
|
||||
roleID := generateUUID(t)
|
||||
_, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", allRules[i].ID)
|
||||
assert.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err))
|
||||
_, err = db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, allRules[i].ID)
|
||||
assert.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err))
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
pm re.PageMeta
|
||||
count int
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list user rules returns only accessible rules",
|
||||
userID: userID,
|
||||
pm: re.PageMeta{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Status: re.AllStatus,
|
||||
},
|
||||
count: 4,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user rules with offset",
|
||||
userID: userID,
|
||||
pm: re.PageMeta{
|
||||
Offset: 2,
|
||||
Limit: 100,
|
||||
Status: re.AllStatus,
|
||||
},
|
||||
count: 2,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user rules with limit",
|
||||
userID: userID,
|
||||
pm: re.PageMeta{
|
||||
Offset: 0,
|
||||
Limit: 2,
|
||||
Status: re.AllStatus,
|
||||
},
|
||||
count: 2,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user rules with domain filter",
|
||||
userID: userID,
|
||||
pm: re.PageMeta{
|
||||
Domain: domainID,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Status: re.AllStatus,
|
||||
},
|
||||
count: 4,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user rules with channel filter",
|
||||
userID: userID,
|
||||
pm: re.PageMeta{
|
||||
InputChannel: channelID,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Status: re.AllStatus,
|
||||
},
|
||||
count: 4,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user rules with non-existing domain returns 0",
|
||||
userID: userID,
|
||||
pm: re.PageMeta{
|
||||
Domain: generateUUID(t),
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Status: re.AllStatus,
|
||||
},
|
||||
count: 0,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules for user with no role assignments returns 0",
|
||||
userID: otherUserID,
|
||||
pm: re.PageMeta{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Status: re.AllStatus,
|
||||
},
|
||||
count: 0,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user rules ordered by name ascending",
|
||||
userID: userID,
|
||||
pm: re.PageMeta{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Status: re.AllStatus,
|
||||
Order: nameOrder,
|
||||
Dir: ascDir,
|
||||
},
|
||||
count: 4,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user rules ordered by created_at descending",
|
||||
userID: userID,
|
||||
pm: re.PageMeta{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Status: re.AllStatus,
|
||||
Order: createdAtOrder,
|
||||
Dir: descDir,
|
||||
},
|
||||
count: 4,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
page, err := repo.ListUserRules(context.Background(), tc.userID, tc.pm)
|
||||
if tc.err != nil {
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
return
|
||||
}
|
||||
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
|
||||
assert.Equal(t, tc.count, len(page.Rules), fmt.Sprintf("%s: expected %d rules, got %d", tc.desc, tc.count, len(page.Rules)))
|
||||
if len(page.Rules) > 1 {
|
||||
switch tc.pm.Order {
|
||||
case nameOrder:
|
||||
if tc.pm.Dir == ascDir {
|
||||
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
|
||||
return page.Rules[i].Name <= page.Rules[j].Name
|
||||
}), "Expected names to be sorted ascending")
|
||||
}
|
||||
case createdAtOrder:
|
||||
if tc.pm.Dir == descDir {
|
||||
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
|
||||
return page.Rules[i].CreatedAt.After(page.Rules[j].CreatedAt)
|
||||
}), "Expected created_at to be sorted descending")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveRule(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM rules")
|
||||
|
||||
+3
-1
@@ -172,6 +172,7 @@ type PageMeta struct {
|
||||
ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time
|
||||
ScheduledAfter *time.Time `json:"scheduled_after,omitempty" db:"scheduled_after"` // Filter rules scheduled after this time
|
||||
Recurring *schedule.Recurring `json:"recurring,omitempty" db:"recurring"` // Filter by recurring type
|
||||
UserID string `json:"user_id,omitempty" db:"user_id"`
|
||||
}
|
||||
|
||||
// EventEncode converts a PageMeta struct to map[string]any.
|
||||
@@ -250,7 +251,8 @@ type Repository interface {
|
||||
UpdateRuleSchedule(ctx context.Context, r Rule) (Rule, error)
|
||||
RemoveRule(ctx context.Context, id string) error
|
||||
UpdateRuleStatus(ctx context.Context, r Rule) (Rule, error)
|
||||
ListRules(ctx context.Context, pm PageMeta) (Page, error)
|
||||
ListAllRules(ctx context.Context, pm PageMeta) (Page, error)
|
||||
ListUserRules(ctx context.Context, userID string, pm PageMeta) (Page, error)
|
||||
UpdateRuleDue(ctx context.Context, id string, due time.Time) (Rule, error)
|
||||
roles.Repository
|
||||
}
|
||||
|
||||
+8
-1
@@ -175,7 +175,14 @@ func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r R
|
||||
|
||||
func (re *re) ListRules(ctx context.Context, session authn.Session, pm PageMeta) (Page, error) {
|
||||
pm.Domain = session.DomainID
|
||||
page, err := re.repo.ListRules(ctx, pm)
|
||||
if session.SuperAdmin {
|
||||
page, err := re.repo.ListAllRules(ctx, pm)
|
||||
if err != nil {
|
||||
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
return page, nil
|
||||
}
|
||||
page, err := re.repo.ListUserRules(ctx, session.UserID, pm)
|
||||
if err != nil {
|
||||
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
+43
-9
@@ -868,11 +868,12 @@ func TestListRules(t *testing.T) {
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
pageMeta re.PageMeta
|
||||
res re.Page
|
||||
err error
|
||||
desc string
|
||||
session authn.Session
|
||||
pageMeta re.PageMeta
|
||||
res re.Page
|
||||
err error
|
||||
superAdmin bool
|
||||
}{
|
||||
{
|
||||
desc: "list rules successfully",
|
||||
@@ -948,11 +949,44 @@ func TestListRules(t *testing.T) {
|
||||
pageMeta: re.PageMeta{},
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
{
|
||||
desc: "list rules as super admin successfully",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
SuperAdmin: true,
|
||||
},
|
||||
pageMeta: re.PageMeta{},
|
||||
res: re.Page{
|
||||
Total: uint64(numRules),
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Rules: rules[0:10],
|
||||
},
|
||||
superAdmin: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules as super admin with failed repo",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
SuperAdmin: true,
|
||||
},
|
||||
pageMeta: re.PageMeta{},
|
||||
superAdmin: true,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
var repoCall *mock.Call
|
||||
if tc.superAdmin {
|
||||
repoCall = repo.On("ListAllRules", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
} else {
|
||||
repoCall = repo.On("ListUserRules", mock.Anything, mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
}
|
||||
res, err := svc.ListRules(context.Background(), tc.session, tc.pageMeta)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
@@ -1778,7 +1812,7 @@ func TestHandle(t *testing.T) {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
var err error
|
||||
|
||||
repoCall := repo.On("ListRules", mock.Anything, re.PageMeta{Domain: tc.message.Domain, InputChannel: tc.message.Channel, Scheduled: &scheduled}).Return(tc.page, tc.listErr).Run(func(args mock.Arguments) {
|
||||
repoCall := repo.On("ListAllRules", mock.Anything, re.PageMeta{Domain: tc.message.Domain, InputChannel: tc.message.Channel, Scheduled: &scheduled}).Return(tc.page, tc.listErr).Run(func(args mock.Arguments) {
|
||||
if tc.listErr != nil {
|
||||
err = tc.listErr
|
||||
}
|
||||
@@ -1854,7 +1888,7 @@ func TestStartScheduler(t *testing.T) {
|
||||
|
||||
for _, tc := range ctxCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(tc.page, tc.listErr)
|
||||
repoCall := repo.On("ListAllRules", mock.Anything, mock.Anything).Return(tc.page, tc.listErr)
|
||||
tickChan := make(chan time.Time)
|
||||
tickCall := ticker.On("Tick").Return((<-chan time.Time)(tickChan))
|
||||
tickCall1 := ticker.On("Stop").Return()
|
||||
@@ -1989,7 +2023,7 @@ func TestStartScheduler(t *testing.T) {
|
||||
Total: uint64(len(tc.rules)),
|
||||
}
|
||||
|
||||
repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(page, tc.listErr)
|
||||
repoCall := repo.On("ListAllRules", mock.Anything, mock.Anything).Return(page, tc.listErr)
|
||||
repoCall2 := repo.On("UpdateRuleDue", mock.Anything, mock.Anything, mock.Anything).Return(re.Rule{}, tc.updateDueErr)
|
||||
tickChan := make(chan time.Time, 1)
|
||||
tickCall := ticker.On("Tick").Return((<-chan time.Time)(tickChan))
|
||||
|
||||
+1
-1
@@ -27,7 +27,7 @@ func (r *report) StartScheduler(ctx context.Context) error {
|
||||
ScheduledBefore: &due,
|
||||
}
|
||||
|
||||
reportConfigs, err := r.repo.ListReportsConfig(ctx, pm)
|
||||
reportConfigs, err := r.repo.ListAllReportsConfig(ctx, pm)
|
||||
if err != nil {
|
||||
r.runInfo <- pkglog.RunInfo{
|
||||
Level: slog.LevelError,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/permissions"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
rolemgr "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
|
||||
@@ -93,8 +94,12 @@ func (am *authorizationMiddleware) RemoveReportConfig(ctx context.Context, sessi
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListReportsConfig(ctx context.Context, session authn.Session, pm reports.PageMeta) (reports.ReportConfigPage, error) {
|
||||
if err := am.authorize(ctx, operations.OpListReportsConfig, session, policies.DomainType, session.DomainID); err != nil {
|
||||
return reports.ReportConfigPage{}, errors.Wrap(errDomainViewConfigs, err)
|
||||
switch err := am.checkSuperAdmin(ctx, session); {
|
||||
case err == nil:
|
||||
session.SuperAdmin = true
|
||||
case errors.Contains(err, svcerr.ErrSuperAdminAction):
|
||||
default:
|
||||
return reports.ReportConfigPage{}, err
|
||||
}
|
||||
|
||||
return am.svc.ListReportsConfig(ctx, session, pm)
|
||||
@@ -187,3 +192,19 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error {
|
||||
if session.Role != authn.SuperAdminRole {
|
||||
return svcerr.ErrSuperAdminAction
|
||||
}
|
||||
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.UserID,
|
||||
Permission: policies.AdminPermission,
|
||||
ObjectType: policies.PlatformType,
|
||||
Object: policies.SuperMQObject,
|
||||
}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+92
-20
@@ -241,6 +241,72 @@ func (_c *Repository_DeleteReportTemplate_Call) RunAndReturn(run func(ctx contex
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListAllReportsConfig provides a mock function for the type Repository
|
||||
func (_mock *Repository) ListAllReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) {
|
||||
ret := _mock.Called(ctx, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListAllReportsConfig")
|
||||
}
|
||||
|
||||
var r0 reports.ReportConfigPage
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) (reports.ReportConfigPage, error)); ok {
|
||||
return returnFunc(ctx, pm)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) reports.ReportConfigPage); ok {
|
||||
r0 = returnFunc(ctx, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(reports.ReportConfigPage)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, reports.PageMeta) error); ok {
|
||||
r1 = returnFunc(ctx, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Repository_ListAllReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllReportsConfig'
|
||||
type Repository_ListAllReportsConfig_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListAllReportsConfig is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - pm reports.PageMeta
|
||||
func (_e *Repository_Expecter) ListAllReportsConfig(ctx interface{}, pm interface{}) *Repository_ListAllReportsConfig_Call {
|
||||
return &Repository_ListAllReportsConfig_Call{Call: _e.mock.On("ListAllReportsConfig", ctx, pm)}
|
||||
}
|
||||
|
||||
func (_c *Repository_ListAllReportsConfig_Call) Run(run func(ctx context.Context, pm reports.PageMeta)) *Repository_ListAllReportsConfig_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 reports.PageMeta
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(reports.PageMeta)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListAllReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListAllReportsConfig_Call {
|
||||
_c.Call.Return(reportConfigPage, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListAllReportsConfig_Call) RunAndReturn(run func(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListAllReportsConfig_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListEntityMembers provides a mock function for the type Repository
|
||||
func (_mock *Repository) ListEntityMembers(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) {
|
||||
ret := _mock.Called(ctx, entityID, pageQuery)
|
||||
@@ -313,68 +379,74 @@ func (_c *Repository_ListEntityMembers_Call) RunAndReturn(run func(ctx context.C
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListReportsConfig provides a mock function for the type Repository
|
||||
func (_mock *Repository) ListReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) {
|
||||
ret := _mock.Called(ctx, pm)
|
||||
// ListUserReportsConfig provides a mock function for the type Repository
|
||||
func (_mock *Repository) ListUserReportsConfig(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error) {
|
||||
ret := _mock.Called(ctx, userID, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListReportsConfig")
|
||||
panic("no return value specified for ListUserReportsConfig")
|
||||
}
|
||||
|
||||
var r0 reports.ReportConfigPage
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) (reports.ReportConfigPage, error)); ok {
|
||||
return returnFunc(ctx, pm)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, reports.PageMeta) (reports.ReportConfigPage, error)); ok {
|
||||
return returnFunc(ctx, userID, pm)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) reports.ReportConfigPage); ok {
|
||||
r0 = returnFunc(ctx, pm)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, reports.PageMeta) reports.ReportConfigPage); ok {
|
||||
r0 = returnFunc(ctx, userID, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(reports.ReportConfigPage)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, reports.PageMeta) error); ok {
|
||||
r1 = returnFunc(ctx, pm)
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, reports.PageMeta) error); ok {
|
||||
r1 = returnFunc(ctx, userID, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Repository_ListReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListReportsConfig'
|
||||
type Repository_ListReportsConfig_Call struct {
|
||||
// Repository_ListUserReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserReportsConfig'
|
||||
type Repository_ListUserReportsConfig_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListReportsConfig is a helper method to define mock.On call
|
||||
// ListUserReportsConfig is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID string
|
||||
// - pm reports.PageMeta
|
||||
func (_e *Repository_Expecter) ListReportsConfig(ctx interface{}, pm interface{}) *Repository_ListReportsConfig_Call {
|
||||
return &Repository_ListReportsConfig_Call{Call: _e.mock.On("ListReportsConfig", ctx, pm)}
|
||||
func (_e *Repository_Expecter) ListUserReportsConfig(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserReportsConfig_Call {
|
||||
return &Repository_ListUserReportsConfig_Call{Call: _e.mock.On("ListUserReportsConfig", ctx, userID, pm)}
|
||||
}
|
||||
|
||||
func (_c *Repository_ListReportsConfig_Call) Run(run func(ctx context.Context, pm reports.PageMeta)) *Repository_ListReportsConfig_Call {
|
||||
func (_c *Repository_ListUserReportsConfig_Call) Run(run func(ctx context.Context, userID string, pm reports.PageMeta)) *Repository_ListUserReportsConfig_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 reports.PageMeta
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(reports.PageMeta)
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 reports.PageMeta
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(reports.PageMeta)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListReportsConfig_Call {
|
||||
func (_c *Repository_ListUserReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListUserReportsConfig_Call {
|
||||
_c.Call.Return(reportConfigPage, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListReportsConfig_Call) RunAndReturn(run func(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListReportsConfig_Call {
|
||||
func (_c *Repository_ListUserReportsConfig_Call) RunAndReturn(run func(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListUserReportsConfig_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -405,39 +405,16 @@ func (repo *PostgresRepository) RemoveReportConfig(ctx context.Context, id strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (repo *PostgresRepository) ListReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) {
|
||||
func (repo *PostgresRepository) ListAllReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) {
|
||||
listReportsQuery := `
|
||||
SELECT id, name, description, domain_id, metrics, email, config,
|
||||
start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status
|
||||
FROM report_config rc %s %s %s;
|
||||
`
|
||||
|
||||
pgData := ""
|
||||
if pm.Limit != 0 {
|
||||
pgData = "LIMIT :limit"
|
||||
}
|
||||
if pm.Offset != 0 {
|
||||
pgData += " OFFSET :offset"
|
||||
}
|
||||
pq := pageReportQuery(pm)
|
||||
|
||||
dir := api.DescDir
|
||||
if pm.Dir == api.AscDir {
|
||||
dir = api.AscDir
|
||||
}
|
||||
|
||||
orderClause := ""
|
||||
|
||||
switch pm.Order {
|
||||
case api.NameKey:
|
||||
orderClause = fmt.Sprintf("ORDER BY name %s, id %s", dir, dir)
|
||||
case api.CreatedAtOrder:
|
||||
orderClause = fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir)
|
||||
case api.UpdatedAtOrder:
|
||||
orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
|
||||
default:
|
||||
orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
|
||||
}
|
||||
orderClause := reportsOrderClause(pm)
|
||||
pgData := reportsPageData(pm)
|
||||
|
||||
q := fmt.Sprintf(listReportsQuery, pq, orderClause, pgData)
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, pm)
|
||||
@@ -474,6 +451,63 @@ func (repo *PostgresRepository) ListReportsConfig(ctx context.Context, pm report
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (repo *PostgresRepository) ListUserReportsConfig(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error) {
|
||||
pq := pageReportQuery(pm)
|
||||
orderClause := reportsOrderClause(pm)
|
||||
pgData := reportsPageData(pm)
|
||||
|
||||
pm.UserID = userID
|
||||
userJoin := `
|
||||
INNER JOIN reports_roles rr ON rr.entity_id = rc.id
|
||||
INNER JOIN reports_role_members rrm ON rrm.role_id = rr.id AND rrm.member_id = :user_id
|
||||
`
|
||||
|
||||
whereClause := pq
|
||||
|
||||
innerQ := fmt.Sprintf(`
|
||||
SELECT DISTINCT rc.id, rc.name, rc.description, rc.domain_id, rc.metrics, rc.email, rc.config,
|
||||
rc.start_datetime, rc.due, rc.recurring, rc.recurring_period, rc.created_at, rc.created_by, rc.updated_at, rc.updated_by, rc.status
|
||||
FROM report_config rc
|
||||
%s
|
||||
%s
|
||||
`, userJoin, whereClause)
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
SELECT * FROM (%s) AS sub %s %s;
|
||||
`, innerQ, orderClause, pgData)
|
||||
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, pm)
|
||||
if err != nil {
|
||||
return reports.ReportConfigPage{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cfgs := []reports.ReportConfig{}
|
||||
for rows.Next() {
|
||||
var r dbReport
|
||||
if err := rows.StructScan(&r); err != nil {
|
||||
return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
rpt, err := dbToReport(r)
|
||||
if err != nil {
|
||||
return reports.ReportConfigPage{}, err
|
||||
}
|
||||
cfgs = append(cfgs, rpt)
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM (%s) AS count_sub;`, innerQ)
|
||||
total, err := postgres.Total(ctx, repo.DB, cq, pm)
|
||||
if err != nil {
|
||||
return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
pm.Total = total
|
||||
|
||||
return reports.ReportConfigPage{
|
||||
PageMeta: pm,
|
||||
ReportConfigs: cfgs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (repo *PostgresRepository) UpdateReportDue(ctx context.Context, id string, due time.Time) (reports.ReportConfig, error) {
|
||||
q := `
|
||||
UPDATE report_config
|
||||
@@ -576,6 +610,32 @@ func (repo *PostgresRepository) DeleteReportTemplate(ctx context.Context, domain
|
||||
return nil
|
||||
}
|
||||
|
||||
func reportsOrderClause(pm reports.PageMeta) string {
|
||||
dir := api.DescDir
|
||||
if pm.Dir == api.AscDir {
|
||||
dir = api.AscDir
|
||||
}
|
||||
switch pm.Order {
|
||||
case api.NameKey:
|
||||
return fmt.Sprintf("ORDER BY name %s, id %s", dir, dir)
|
||||
case api.CreatedAtOrder:
|
||||
return fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir)
|
||||
default:
|
||||
return fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir)
|
||||
}
|
||||
}
|
||||
|
||||
func reportsPageData(pm reports.PageMeta) string {
|
||||
pgData := ""
|
||||
if pm.Limit != 0 {
|
||||
pgData = "LIMIT :limit"
|
||||
}
|
||||
if pm.Offset != 0 {
|
||||
pgData += " OFFSET :offset"
|
||||
}
|
||||
return pgData
|
||||
}
|
||||
|
||||
func pageReportQuery(pm reports.PageMeta) string {
|
||||
var query []string
|
||||
if pm.Status != reports.AllStatus {
|
||||
|
||||
@@ -472,7 +472,7 @@ func TestListReportsConfig(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
page, err := repo.ListReportsConfig(context.Background(), tc.pageMeta)
|
||||
page, err := repo.ListAllReportsConfig(context.Background(), tc.pageMeta)
|
||||
if tc.err != nil {
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
return
|
||||
@@ -483,6 +483,132 @@ func TestListReportsConfig(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserReportsConfig(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM report_config")
|
||||
require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err))
|
||||
})
|
||||
|
||||
repo := postgres.NewRepository(database)
|
||||
|
||||
domainID := generateUUID(t)
|
||||
userID := generateUUID(t)
|
||||
otherUserID := generateUUID(t)
|
||||
|
||||
num := 10
|
||||
var allCfgs []reports.ReportConfig
|
||||
for i := range num {
|
||||
cfg := reports.ReportConfig{
|
||||
ID: generateUUID(t),
|
||||
Name: fmt.Sprintf("Report-%d", i),
|
||||
DomainID: domainID,
|
||||
Status: reports.EnabledStatus,
|
||||
CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute),
|
||||
UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute),
|
||||
Metrics: []reports.ReqMetric{},
|
||||
}
|
||||
cfg, err := repo.AddReportConfig(context.Background(), cfg)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
|
||||
allCfgs = append(allCfgs, cfg)
|
||||
}
|
||||
|
||||
// Assign userID to the first 5 report configs via direct role INSERT.
|
||||
for i := range 5 {
|
||||
roleID := generateUUID(t)
|
||||
_, err := db.Exec(`INSERT INTO reports_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", allCfgs[i].ID)
|
||||
require.Nil(t, err, fmt.Sprintf("insert reports_roles unexpected error: %s", err))
|
||||
_, err = db.Exec(`INSERT INTO reports_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, allCfgs[i].ID)
|
||||
require.Nil(t, err, fmt.Sprintf("insert reports_role_members unexpected error: %s", err))
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
pageMeta reports.PageMeta
|
||||
size int
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list user reports returns only accessible reports",
|
||||
userID: userID,
|
||||
pageMeta: reports.PageMeta{
|
||||
Domain: domainID,
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
size: 5,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user reports with limit",
|
||||
userID: userID,
|
||||
pageMeta: reports.PageMeta{
|
||||
Domain: domainID,
|
||||
Limit: 3,
|
||||
Offset: 0,
|
||||
},
|
||||
size: 3,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user reports with offset",
|
||||
userID: userID,
|
||||
pageMeta: reports.PageMeta{
|
||||
Domain: domainID,
|
||||
Limit: 100,
|
||||
Offset: 3,
|
||||
},
|
||||
size: 2,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user reports with enabled status filter",
|
||||
userID: userID,
|
||||
pageMeta: reports.PageMeta{
|
||||
Domain: domainID,
|
||||
Limit: 100,
|
||||
Status: reports.EnabledStatus,
|
||||
},
|
||||
size: 5,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list reports for user with no role assignments returns 0",
|
||||
userID: otherUserID,
|
||||
pageMeta: reports.PageMeta{
|
||||
Domain: domainID,
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
size: 0,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user reports with non-existing domain returns 0",
|
||||
userID: userID,
|
||||
pageMeta: reports.PageMeta{
|
||||
Domain: generateUUID(t),
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
size: 0,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
page, err := repo.ListUserReportsConfig(context.Background(), tc.userID, tc.pageMeta)
|
||||
if tc.err != nil {
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
return
|
||||
}
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
|
||||
require.Equal(t, tc.size, len(page.ReportConfigs), fmt.Sprintf("%s: expected %d reports, got %d", tc.desc, tc.size, len(page.ReportConfigs)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateReportSchedule(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM report_config")
|
||||
|
||||
+3
-1
@@ -395,6 +395,7 @@ type PageMeta struct {
|
||||
Domain string `json:"domain_id,omitempty" db:"domain_id"`
|
||||
ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time
|
||||
ScheduledAfter *time.Time `json:"scheduled_after,omitempty" db:"scheduled_after"` // Filter rules scheduled after this time
|
||||
UserID string `json:"user_id,omitempty" db:"user_id"`
|
||||
}
|
||||
|
||||
type Repository interface {
|
||||
@@ -405,7 +406,8 @@ type Repository interface {
|
||||
UpdateReportSchedule(ctx context.Context, cfg ReportConfig) (ReportConfig, error)
|
||||
RemoveReportConfig(ctx context.Context, id string) error
|
||||
UpdateReportConfigStatus(ctx context.Context, cfg ReportConfig) (ReportConfig, error)
|
||||
ListReportsConfig(ctx context.Context, pm PageMeta) (ReportConfigPage, error)
|
||||
ListAllReportsConfig(ctx context.Context, pm PageMeta) (ReportConfigPage, error)
|
||||
ListUserReportsConfig(ctx context.Context, userID string, pm PageMeta) (ReportConfigPage, error)
|
||||
UpdateReportDue(ctx context.Context, id string, due time.Time) (ReportConfig, error)
|
||||
|
||||
UpdateReportTemplate(ctx context.Context, domainID, reportID string, template ReportTemplate) error
|
||||
|
||||
+8
-1
@@ -159,7 +159,14 @@ func (r *report) RemoveReportConfig(ctx context.Context, session authn.Session,
|
||||
|
||||
func (r *report) ListReportsConfig(ctx context.Context, session authn.Session, pm PageMeta) (ReportConfigPage, error) {
|
||||
pm.Domain = session.DomainID
|
||||
page, err := r.repo.ListReportsConfig(ctx, pm)
|
||||
if session.SuperAdmin {
|
||||
page, err := r.repo.ListAllReportsConfig(ctx, pm)
|
||||
if err != nil {
|
||||
return ReportConfigPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
return page, nil
|
||||
}
|
||||
page, err := r.repo.ListUserReportsConfig(ctx, session.UserID, pm)
|
||||
if err != nil {
|
||||
return ReportConfigPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
+42
-6
@@ -328,11 +328,12 @@ func TestListReportsConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
pageMeta reports.PageMeta
|
||||
res reports.ReportConfigPage
|
||||
err error
|
||||
desc string
|
||||
session authn.Session
|
||||
pageMeta reports.PageMeta
|
||||
res reports.ReportConfigPage
|
||||
err error
|
||||
superAdmin bool
|
||||
}{
|
||||
{
|
||||
desc: "list report configs successfully",
|
||||
@@ -399,11 +400,46 @@ func TestListReportsConfig(t *testing.T) {
|
||||
pageMeta: reports.PageMeta{},
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
{
|
||||
desc: "list report configs as super admin successfully",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
SuperAdmin: true,
|
||||
},
|
||||
pageMeta: reports.PageMeta{},
|
||||
res: reports.ReportConfigPage{
|
||||
PageMeta: reports.PageMeta{
|
||||
Total: uint64(numConfigs),
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
ReportConfigs: configs[0:10],
|
||||
},
|
||||
superAdmin: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list report configs as super admin with failed repo",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
SuperAdmin: true,
|
||||
},
|
||||
pageMeta: reports.PageMeta{},
|
||||
superAdmin: true,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("ListReportsConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
var repoCall *mock.Call
|
||||
if tc.superAdmin {
|
||||
repoCall = repo.On("ListAllReportsConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
} else {
|
||||
repoCall = repo.On("ListUserReportsConfig", mock.Anything, mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
}
|
||||
res, err := svc.ListReportsConfig(context.Background(), tc.session, tc.pageMeta)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
|
||||
Reference in New Issue
Block a user