mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 06:50:18 +00:00
NOISSUE - Improve SQL queries performance and safety (#3378)
Continuous Delivery / lint-and-build (push) Has been cancelled
Deploy GitHub Pages / swagger-ui (push) Has been cancelled
CI Pipeline / Check Certs (push) Has been cancelled
CI Pipeline / Lint Proto (push) Has been cancelled
CI Pipeline / Detect Changes (push) Has been cancelled
Continuous Delivery / Build and Push Docker Images (push) Has been cancelled
CI Pipeline / lint-and-build (push) Has been cancelled
CI Pipeline / Test ${{ matrix.module }} (push) Has been cancelled
CI Pipeline / Upload Coverage (push) Has been cancelled
Continuous Delivery / lint-and-build (push) Has been cancelled
Deploy GitHub Pages / swagger-ui (push) Has been cancelled
CI Pipeline / Check Certs (push) Has been cancelled
CI Pipeline / Lint Proto (push) Has been cancelled
CI Pipeline / Detect Changes (push) Has been cancelled
Continuous Delivery / Build and Push Docker Images (push) Has been cancelled
CI Pipeline / lint-and-build (push) Has been cancelled
CI Pipeline / Test ${{ matrix.module }} (push) Has been cancelled
CI Pipeline / Upload Coverage (push) Has been cancelled
Signed-off-by: dusan <borovcanindusan1@gmail.com>
This commit is contained in:
@@ -134,6 +134,15 @@ func Migration() *migrate.MemoryMigrationSource {
|
||||
`ALTER TABLE pat_scopes RENAME COLUMN domain_id TO optional_domain_id;`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "auth_8",
|
||||
Up: []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_pats_user_id ON pats(user_id);`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP INDEX IF EXISTS idx_pats_user_id;`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ type dbPat struct {
|
||||
Revoked bool `db:"revoked,omitempty"`
|
||||
RevokedAt sql.NullTime `db:"revoked_at,omitempty"`
|
||||
Status auth.Status `db:"status,omitempty"`
|
||||
TotalCount uint64 `db:"total_count"`
|
||||
}
|
||||
|
||||
type dbScope struct {
|
||||
|
||||
+44
-42
@@ -72,14 +72,15 @@ func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSP
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
SELECT
|
||||
SELECT
|
||||
p.id, p.user_id, p.name, p.description, p.issued_at, p.expires_at,
|
||||
p.updated_at, p.revoked, p.revoked_at,
|
||||
CASE
|
||||
CASE
|
||||
WHEN p.revoked = TRUE THEN %d
|
||||
WHEN expires_at IS NOT NULL AND expires_at < :timestamp THEN %d
|
||||
ELSE %d
|
||||
END AS status
|
||||
END AS status,
|
||||
COUNT(*) OVER() AS total_count
|
||||
FROM pats p WHERE user_id = :user_id %s
|
||||
ORDER BY issued_at DESC
|
||||
LIMIT :limit OFFSET :offset`, auth.RevokedStatus, auth.ExpiredStatus, auth.ActiveStatus, pageQuery)
|
||||
@@ -100,30 +101,31 @@ func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSP
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var total uint64
|
||||
items := []auth.PAT{}
|
||||
for rows.Next() {
|
||||
var pat dbPat
|
||||
if err := rows.StructScan(&pat); err != nil {
|
||||
return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
total = pat.TotalCount
|
||||
items = append(items, toAuthPat(pat))
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM pats p WHERE user_id = :user_id %s`, pageQuery)
|
||||
|
||||
total, err := postgres.Total(ctx, pr.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
if len(items) == 0 {
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM pats p WHERE user_id = :user_id %s`, pageQuery)
|
||||
total, err = postgres.Total(ctx, pr.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
page := auth.PATSPage{
|
||||
return auth.PATSPage{
|
||||
PATS: items,
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
}
|
||||
return page, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func PageQuery(pm auth.PATSPageMeta) (string, error) {
|
||||
@@ -408,14 +410,15 @@ func (pr *patRepo) AddScope(ctx context.Context, userID string, scopes []auth.Sc
|
||||
|
||||
func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope, error) {
|
||||
q := `
|
||||
SELECT COUNT(*)
|
||||
FROM pat_scopes
|
||||
WHERE pat_id = :pat_id
|
||||
AND entity_type = :entity_type
|
||||
AND domain_id = :domain_id
|
||||
AND operation = :operation
|
||||
AND entity_id = :entity_id
|
||||
LIMIT 1`
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM pat_scopes
|
||||
WHERE pat_id = :pat_id
|
||||
AND entity_type = :entity_type
|
||||
AND domain_id = :domain_id
|
||||
AND operation = :operation
|
||||
AND entity_id = :entity_id
|
||||
)`
|
||||
|
||||
params := dbScope{
|
||||
PatID: sc.PatID,
|
||||
@@ -431,18 +434,28 @@ func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope,
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var count int
|
||||
var exists bool
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
if err := rows.Scan(&exists); err != nil {
|
||||
return auth.Scope{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
if exists {
|
||||
return auth.Scope{}, repoerr.ErrConflict
|
||||
}
|
||||
|
||||
if sc.EntityID == auth.AnyIDs {
|
||||
checkEntityQuery := `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM pat_scopes
|
||||
WHERE pat_id = :pat_id
|
||||
AND entity_type = :entity_type
|
||||
AND domain_id = :domain_id
|
||||
AND operation = :operation
|
||||
)`
|
||||
|
||||
newParams := dbScope{
|
||||
PatID: sc.PatID,
|
||||
DomainID: sc.DomainID,
|
||||
@@ -450,42 +463,31 @@ func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope,
|
||||
Operation: sc.Operation,
|
||||
}
|
||||
|
||||
checkEntityQuery := `
|
||||
SELECT COUNT(*)
|
||||
FROM pat_scopes
|
||||
WHERE pat_id = :pat_id
|
||||
AND entity_type = :entity_type
|
||||
AND domain_id = :domain_id
|
||||
AND operation = :operation
|
||||
LIMIT 1`
|
||||
|
||||
rows, err := pr.db.NamedQueryContext(ctx, checkEntityQuery, newParams)
|
||||
if err != nil {
|
||||
return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var count int
|
||||
var scopeExists bool
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
if err := rows.Scan(&scopeExists); err != nil {
|
||||
return auth.Scope{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
if scopeExists {
|
||||
updateWithWildcardQuery := `
|
||||
UPDATE pat_scopes
|
||||
SET entity_id = :entity_id
|
||||
WHERE pat_id = :pat_id
|
||||
UPDATE pat_scopes
|
||||
SET entity_id = :entity_id
|
||||
WHERE pat_id = :pat_id
|
||||
AND entity_type = :entity_type
|
||||
AND domain_id = :domain_id
|
||||
AND operation = :operation`
|
||||
|
||||
rows, err = pr.db.NamedQueryContext(ctx, updateWithWildcardQuery, params)
|
||||
if err != nil {
|
||||
if _, err := pr.db.NamedExecContext(ctx, updateWithWildcardQuery, params); err != nil {
|
||||
return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return auth.Scope{}, nil
|
||||
}
|
||||
@@ -657,7 +659,7 @@ func (pr *patRepo) retrievePATFromDB(ctx context.Context, userID, patID string)
|
||||
WHEN revoked = TRUE THEN %d
|
||||
WHEN expires_at IS NOT NULL AND expires_at < :timestamp THEN %d
|
||||
ELSE %d
|
||||
END AS status
|
||||
END AS status
|
||||
FROM pats WHERE user_id = :user_id AND id = :id`, auth.RevokedStatus, auth.ExpiredStatus, auth.ActiveStatus)
|
||||
|
||||
dbp := dbPagemeta{
|
||||
|
||||
+102
-94
@@ -183,7 +183,7 @@ func (cr *channelRepository) RetrieveByIDWithRoles(ctx context.Context, id, memb
|
||||
SELECT
|
||||
c.id,
|
||||
c.parent_group_id,
|
||||
COALESCE(g."path", ''::::ltree) AS parent_group_path,
|
||||
COALESCE(g."path", CAST('' AS ltree)) AS parent_group_path,
|
||||
c.domain_id
|
||||
FROM
|
||||
channels c
|
||||
@@ -201,7 +201,7 @@ func (cr *channelRepository) RetrieveByIDWithRoles(ctx context.Context, id, memb
|
||||
cr."name" AS role_name,
|
||||
jsonb_agg(DISTINCT cra."action") AS actions,
|
||||
'direct' AS access_type,
|
||||
''::::ltree AS access_provider_path,
|
||||
CAST('' AS ltree) AS access_provider_path,
|
||||
'' AS access_provider_id
|
||||
FROM
|
||||
channels_roles cr
|
||||
@@ -260,7 +260,7 @@ func (cr *channelRepository) RetrieveByIDWithRoles(ctx context.Context, id, memb
|
||||
dr.id AS role_id,
|
||||
dr."name" AS role_name,
|
||||
jsonb_agg(DISTINCT all_actions."action") AS actions,
|
||||
''::::ltree access_provider_path,
|
||||
CAST('' AS ltree) access_provider_path,
|
||||
'domain' AS access_type,
|
||||
dr.entity_id AS access_provider_id
|
||||
FROM
|
||||
@@ -414,7 +414,7 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.Page)
|
||||
COALESCE(c.domain_id, '') AS domain_id,
|
||||
COALESCE(parent_group_id, '') AS parent_group_id,
|
||||
c.route,
|
||||
COALESCE((SELECT path FROM groups WHERE id = c.parent_group_id), ''::::ltree) AS parent_group_path,
|
||||
COALESCE(g.path, CAST('' AS ltree)) AS parent_group_path,
|
||||
c.status,
|
||||
c.created_by,
|
||||
c.created_at,
|
||||
@@ -422,6 +422,8 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.Page)
|
||||
COALESCE(c.updated_by, '') AS updated_by
|
||||
FROM
|
||||
channels c
|
||||
LEFT JOIN
|
||||
groups g ON g.id = c.parent_group_id
|
||||
)
|
||||
SELECT
|
||||
c.*
|
||||
@@ -492,7 +494,7 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
|
||||
return channels.ChannelsPage{}, err
|
||||
}
|
||||
|
||||
bq := repo.userChannelsBaseQuery(domainID, userID)
|
||||
bq := userChannelsBaseQuery
|
||||
|
||||
connJoinQuery := `
|
||||
FROM
|
||||
@@ -517,6 +519,34 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
|
||||
`
|
||||
}
|
||||
|
||||
dbPage, err := toDBChannelsPage(pm)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
dbPage.UserID = userID
|
||||
dbPage.DomainID = domainID
|
||||
|
||||
if pm.OnlyTotal {
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM final_channels c
|
||||
%s;
|
||||
`, bq, pageQuery)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return channels.ChannelsPage{
|
||||
Page: channels.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
%s
|
||||
SELECT
|
||||
@@ -540,7 +570,8 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
|
||||
c.access_provider_id,
|
||||
c.access_provider_role_id,
|
||||
c.access_provider_role_name,
|
||||
c.access_provider_role_actions
|
||||
c.access_provider_role_actions,
|
||||
COUNT(*) OVER() AS total_count
|
||||
%s
|
||||
%s
|
||||
`, bq, connJoinQuery, pageQuery)
|
||||
@@ -549,83 +580,54 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
|
||||
|
||||
q = applyLimitOffset(q)
|
||||
|
||||
dbPage, err := toDBChannelsPage(pm)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
var items []channels.Channel
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
dbc := dbChannel{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := toChannel(dbc)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, err
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
}
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM (
|
||||
SELECT
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.route,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_by,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
c.parent_group_path,
|
||||
c.role_id,
|
||||
c.role_name,
|
||||
c.actions,
|
||||
c.access_type,
|
||||
c.access_provider_id,
|
||||
c.access_provider_role_id,
|
||||
c.access_provider_role_name,
|
||||
c.access_provider_role_actions
|
||||
%s
|
||||
%s
|
||||
) AS subquery;
|
||||
`, bq, connJoinQuery, pageQuery)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
page := channels.ChannelsPage{
|
||||
var total uint64
|
||||
var items []channels.Channel
|
||||
for rows.Next() {
|
||||
dbc := dbChannel{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
total = dbc.TotalCount
|
||||
|
||||
c, err := toChannel(dbc)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, err
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM final_channels c
|
||||
%s;
|
||||
`, bq, pageQuery)
|
||||
|
||||
total, err = postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
return channels.ChannelsPage{
|
||||
Channels: items,
|
||||
Page: channels.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}
|
||||
|
||||
return page, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (repo *channelRepository) userChannelsBaseQuery(domainID, userID string) string {
|
||||
return fmt.Sprintf(`
|
||||
const userChannelsBaseQuery = `
|
||||
WITH direct_channels AS (
|
||||
select
|
||||
c.id,
|
||||
@@ -640,7 +642,7 @@ WITH direct_channels AS (
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
COALESCE((SELECT path FROM groups WHERE id = c.parent_group_id), ''::::ltree) AS parent_group_path,
|
||||
COALESCE(pg.path, CAST('' AS ltree)) AS parent_group_path,
|
||||
cr.id AS role_id,
|
||||
cr."name" AS role_name,
|
||||
array_agg(cra."action") AS actions,
|
||||
@@ -648,7 +650,7 @@ WITH direct_channels AS (
|
||||
'' AS access_provider_id,
|
||||
'' AS access_provider_role_id,
|
||||
'' AS access_provider_role_name,
|
||||
array[]::::text[] AS access_provider_role_actions
|
||||
CAST(array[] AS text[]) AS access_provider_role_actions
|
||||
FROM
|
||||
channels_role_members crm
|
||||
JOIN
|
||||
@@ -657,11 +659,13 @@ WITH direct_channels AS (
|
||||
channels_roles cr ON cr.id = crm.role_id
|
||||
JOIN
|
||||
channels c ON c.id = cr.entity_id
|
||||
LEFT JOIN
|
||||
groups pg ON pg.id = c.parent_group_id
|
||||
WHERE
|
||||
crm.member_id = '%s'
|
||||
AND c.domain_id = '%s'
|
||||
crm.member_id = :user_id
|
||||
AND c.domain_id = :domain_id_param
|
||||
GROUP BY
|
||||
cr.entity_id, crm.member_id, cr.id, cr."name", c.id
|
||||
cr.entity_id, crm.member_id, cr.id, cr."name", c.id, pg.path
|
||||
),
|
||||
direct_groups AS (
|
||||
SELECT
|
||||
@@ -682,9 +686,9 @@ direct_groups AS (
|
||||
JOIN
|
||||
groups_role_actions all_actions ON all_actions.role_id = grm.role_id
|
||||
WHERE
|
||||
grm.member_id = '%s'
|
||||
AND g.domain_id = '%s'
|
||||
AND gra."action" LIKE 'channel%%'
|
||||
grm.member_id = :user_id
|
||||
AND g.domain_id = :domain_id_param
|
||||
AND gra."action" LIKE 'channel%'
|
||||
GROUP BY
|
||||
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
|
||||
),
|
||||
@@ -707,9 +711,9 @@ direct_groups_with_subgroup AS (
|
||||
JOIN
|
||||
groups_role_actions all_actions ON all_actions.role_id = grm.role_id
|
||||
WHERE
|
||||
grm.member_id = '%s'
|
||||
AND g.domain_id = '%s'
|
||||
AND gra."action" LIKE 'subgroup_channel%%'
|
||||
grm.member_id = :user_id
|
||||
AND g.domain_id = :domain_id_param
|
||||
AND gra."action" LIKE 'subgroup_channel%'
|
||||
GROUP BY
|
||||
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
|
||||
),
|
||||
@@ -737,7 +741,7 @@ indirect_child_groups AS (
|
||||
JOIN
|
||||
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path
|
||||
WHERE
|
||||
indirect_child_groups.domain_id = '%s'
|
||||
indirect_child_groups.domain_id = :domain_id_param
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM direct_groups_with_subgroup dgws
|
||||
@@ -759,7 +763,7 @@ final_groups AS (
|
||||
"path",
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
CAST(array[] AS text[]) AS actions,
|
||||
'direct_group' AS access_type,
|
||||
id AS access_provider_id,
|
||||
role_id AS access_provider_role_id,
|
||||
@@ -782,7 +786,7 @@ final_groups AS (
|
||||
"path",
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
CAST(array[] AS text[]) AS actions,
|
||||
'indirect_group' AS access_type,
|
||||
access_provider_id,
|
||||
access_provider_role_id,
|
||||
@@ -819,7 +823,7 @@ groups_channels AS (
|
||||
JOIN
|
||||
channels c ON c.parent_group_id = g.id
|
||||
WHERE
|
||||
c.id NOT IN (SELECT id FROM direct_channels)
|
||||
NOT EXISTS (SELECT 1 FROM direct_channels dc WHERE dc.id = c.id)
|
||||
UNION
|
||||
SELECT * FROM direct_channels
|
||||
),
|
||||
@@ -865,7 +869,7 @@ final_channels AS (
|
||||
g."path" AS parent_group_path,
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
CAST(array[] AS text[]) AS actions,
|
||||
'domain' AS access_type,
|
||||
d.id AS access_provider_id,
|
||||
dr.id AS access_provider_role_id,
|
||||
@@ -884,18 +888,17 @@ final_channels AS (
|
||||
LEFT JOIN
|
||||
groups g ON dc.parent_group_id = g.id
|
||||
WHERE
|
||||
drm.member_id = '%s' -- user_id
|
||||
AND d.id = '%s' -- domain_id
|
||||
AND dra."action" LIKE 'channel_%%'
|
||||
AND NOT EXISTS ( -- Ensures that the direct and indirect channels are not in included.
|
||||
drm.member_id = :user_id
|
||||
AND d.id = :domain_id_param
|
||||
AND dra."action" LIKE 'channel_%'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM groups_channels gc
|
||||
WHERE gc.id = dc.id
|
||||
)
|
||||
GROUP BY
|
||||
dc.id, d.id, dr.id, g."path"
|
||||
)
|
||||
`, userID, domainID, userID, domainID, userID, domainID, domainID, userID, domainID)
|
||||
}
|
||||
`
|
||||
|
||||
func (cr *channelRepository) Remove(ctx context.Context, ids ...string) error {
|
||||
q := "DELETE FROM channels AS c WHERE c.id = ANY(:channel_ids) ;"
|
||||
@@ -1145,6 +1148,7 @@ type dbChannel struct {
|
||||
ConnectionTypes pq.Int32Array `db:"connection_types,omitempty"`
|
||||
MemberID string `db:"member_id,omitempty"`
|
||||
Roles json.RawMessage `db:"roles,omitempty"`
|
||||
TotalCount uint64 `db:"total_count"`
|
||||
}
|
||||
|
||||
func toDBChannel(ch channels.Channel) (dbChannel, error) {
|
||||
@@ -1303,7 +1307,7 @@ func PageQuery(pm channels.Page) (string, error) {
|
||||
}
|
||||
|
||||
if len(pm.IDs) != 0 {
|
||||
query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','")))
|
||||
query = append(query, "id = ANY(:ids)")
|
||||
}
|
||||
if pm.Status != channels.AllStatus {
|
||||
query = append(query, "c.status = :status")
|
||||
@@ -1415,6 +1419,7 @@ func toDBChannelsPage(pm channels.Page) (dbChannelsPage, error) {
|
||||
RoleID: pm.RoleID,
|
||||
Actions: pm.Actions,
|
||||
AccessType: pm.AccessType,
|
||||
IDs: pq.StringArray(pm.IDs),
|
||||
CreatedFrom: pm.CreatedFrom,
|
||||
CreatedTo: pm.CreatedTo,
|
||||
}, nil
|
||||
@@ -1438,6 +1443,9 @@ type dbChannelsPage struct {
|
||||
AccessType string `db:"access_type"`
|
||||
CreatedFrom time.Time `db:"created_from"`
|
||||
CreatedTo time.Time `db:"created_to"`
|
||||
IDs pq.StringArray `db:"ids"`
|
||||
UserID string `db:"user_id"`
|
||||
DomainID string `db:"domain_id_param"`
|
||||
}
|
||||
|
||||
type dbConnection struct {
|
||||
|
||||
@@ -95,6 +95,17 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
|
||||
`SELECT 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "channels_06",
|
||||
Up: []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_channels_domain_id_status ON channels(domain_id, status);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_channels_parent_group_id ON channels(parent_group_id);`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP INDEX IF EXISTS idx_channels_domain_id_status;`,
|
||||
`DROP INDEX IF EXISTS idx_channels_parent_group_id;`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
channelsMigration.Migrations = append(channelsMigration.Migrations, rolesMigration.Migrations...)
|
||||
|
||||
+105
-98
@@ -195,7 +195,7 @@ func (repo *clientRepo) RetrieveByIDWithRoles(ctx context.Context, id, memberID
|
||||
SELECT
|
||||
c.id,
|
||||
c.parent_group_id,
|
||||
COALESCE(g."path", ''::::ltree) AS parent_group_path,
|
||||
COALESCE(g."path", CAST('' AS ltree)) AS parent_group_path,
|
||||
c.domain_id
|
||||
FROM
|
||||
clients c
|
||||
@@ -213,7 +213,7 @@ func (repo *clientRepo) RetrieveByIDWithRoles(ctx context.Context, id, memberID
|
||||
cr."name" AS role_name,
|
||||
jsonb_agg(DISTINCT cra."action") AS actions,
|
||||
'direct' AS access_type,
|
||||
''::::ltree AS access_provider_path,
|
||||
CAST('' AS ltree) AS access_provider_path,
|
||||
'' AS access_provider_id
|
||||
FROM
|
||||
clients_roles cr
|
||||
@@ -272,7 +272,7 @@ func (repo *clientRepo) RetrieveByIDWithRoles(ctx context.Context, id, memberID
|
||||
dr.id AS role_id,
|
||||
dr."name" AS role_name,
|
||||
jsonb_agg(DISTINCT all_actions."action") AS actions,
|
||||
''::::ltree access_provider_path,
|
||||
CAST('' AS ltree) access_provider_path,
|
||||
'domain' AS access_type,
|
||||
dr.entity_id AS access_provider_id
|
||||
FROM
|
||||
@@ -452,13 +452,15 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien
|
||||
c.metadata,
|
||||
COALESCE(c.domain_id, '') AS domain_id,
|
||||
COALESCE(parent_group_id, '') AS parent_group_id,
|
||||
COALESCE((SELECT path FROM groups WHERE id = c.parent_group_id), ''::::ltree) AS parent_group_path,
|
||||
COALESCE(g.path, CAST('' AS ltree)) AS parent_group_path,
|
||||
c.status,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
COALESCE(c.updated_by, '') AS updated_by
|
||||
FROM
|
||||
clients c
|
||||
LEFT JOIN
|
||||
groups g ON g.id = c.parent_group_id
|
||||
)
|
||||
SELECT
|
||||
c.*
|
||||
@@ -529,7 +531,7 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
|
||||
return clients.ClientsPage{}, err
|
||||
}
|
||||
|
||||
bq := repo.userClientBaseQuery(domainID, userID)
|
||||
bq := userClientBaseQuery
|
||||
|
||||
connJoinQuery := `
|
||||
FROM
|
||||
@@ -554,6 +556,34 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
|
||||
`
|
||||
}
|
||||
|
||||
dbPage, err := ToDBClientsPage(pm)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
dbPage.UserID = userID
|
||||
dbPage.DomainID = domainID
|
||||
|
||||
if pm.OnlyTotal {
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM final_clients c
|
||||
%s;
|
||||
`, bq, pageQuery)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
%s
|
||||
SELECT
|
||||
@@ -577,7 +607,8 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
|
||||
c.access_provider_id,
|
||||
c.access_provider_role_id,
|
||||
c.access_provider_role_name,
|
||||
c.access_provider_role_actions
|
||||
c.access_provider_role_actions,
|
||||
COUNT(*) OVER() AS total_count
|
||||
%s
|
||||
%s
|
||||
`, bq, connJoinQuery, pageQuery)
|
||||
@@ -586,83 +617,54 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
|
||||
|
||||
q = applyLimitOffset(q)
|
||||
|
||||
dbPage, err := ToDBClientsPage(pm)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
var items []clients.Client
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
dbc := DBClient{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToClient(dbc)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, err
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
}
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM (
|
||||
SELECT
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.identity,
|
||||
c.secret,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
c.parent_group_path,
|
||||
c.role_id,
|
||||
c.role_name,
|
||||
c.actions,
|
||||
c.access_type,
|
||||
c.access_provider_id,
|
||||
c.access_provider_role_id,
|
||||
c.access_provider_role_name,
|
||||
c.access_provider_role_actions
|
||||
%s
|
||||
%s
|
||||
) AS subquery;
|
||||
`, bq, connJoinQuery, pageQuery)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
page := clients.ClientsPage{
|
||||
var total uint64
|
||||
var items []clients.Client
|
||||
for rows.Next() {
|
||||
dbc := DBClient{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
total = dbc.TotalCount
|
||||
|
||||
c, err := ToClient(dbc)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, err
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM final_clients c
|
||||
%s;
|
||||
`, bq, pageQuery)
|
||||
|
||||
total, err = postgres.Total(ctx, repo.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
return clients.ClientsPage{
|
||||
Clients: items,
|
||||
Page: clients.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}
|
||||
|
||||
return page, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
return fmt.Sprintf(`
|
||||
const userClientBaseQuery = `
|
||||
WITH direct_clients AS (
|
||||
SELECT
|
||||
c.id,
|
||||
@@ -677,7 +679,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
COALESCE((SELECT path FROM groups WHERE id = c.parent_group_id), ''::::ltree) AS parent_group_path,
|
||||
COALESCE(pg.path, CAST('' AS ltree)) AS parent_group_path,
|
||||
cr.id AS role_id,
|
||||
cr."name" AS role_name,
|
||||
array_agg(cra."action") AS actions,
|
||||
@@ -685,7 +687,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
'' AS access_provider_id,
|
||||
'' AS access_provider_role_id,
|
||||
'' AS access_provider_role_name,
|
||||
array[]::::text[] AS access_provider_role_actions
|
||||
CAST(array[] AS text[]) AS access_provider_role_actions
|
||||
FROM
|
||||
clients_role_members crm
|
||||
JOIN
|
||||
@@ -694,11 +696,13 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
clients_roles cr ON cr.id = crm.role_id
|
||||
JOIN
|
||||
clients c ON c.id = cr.entity_id
|
||||
LEFT JOIN
|
||||
groups pg ON pg.id = c.parent_group_id
|
||||
WHERE
|
||||
crm.member_id = '%s'
|
||||
AND c.domain_id = '%s'
|
||||
crm.member_id = :user_id
|
||||
AND c.domain_id = :domain_id_param
|
||||
GROUP BY
|
||||
cr.entity_id, crm.member_id, cr.id, cr."name", c.id
|
||||
cr.entity_id, crm.member_id, cr.id, cr."name", c.id, pg.path
|
||||
),
|
||||
direct_groups AS (
|
||||
SELECT
|
||||
@@ -719,9 +723,9 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
JOIN
|
||||
groups_role_actions all_actions ON all_actions.role_id = grm.role_id
|
||||
WHERE
|
||||
grm.member_id = '%s'
|
||||
AND g.domain_id = '%s'
|
||||
AND gra."action" LIKE 'client%%'
|
||||
grm.member_id = :user_id
|
||||
AND g.domain_id = :domain_id_param
|
||||
AND gra."action" LIKE 'client%'
|
||||
GROUP BY
|
||||
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
|
||||
),
|
||||
@@ -744,9 +748,9 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
JOIN
|
||||
groups_role_actions all_actions ON all_actions.role_id = grm.role_id
|
||||
WHERE
|
||||
grm.member_id = '%s'
|
||||
AND g.domain_id = '%s'
|
||||
AND gra."action" LIKE 'subgroup_client%%'
|
||||
grm.member_id = :user_id
|
||||
AND g.domain_id = :domain_id_param
|
||||
AND gra."action" LIKE 'subgroup_client%'
|
||||
GROUP BY
|
||||
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
|
||||
),
|
||||
@@ -772,10 +776,10 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
FROM
|
||||
direct_leaf_groups_with_subgroup dlgws
|
||||
JOIN
|
||||
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path -- Finds all children of entity_id based on ltree path
|
||||
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path
|
||||
WHERE
|
||||
indirect_child_groups.domain_id = '%s'
|
||||
AND NOT EXISTS (
|
||||
indirect_child_groups.domain_id = :domain_id_param
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM direct_groups_with_subgroup dgws
|
||||
WHERE dgws.id = indirect_child_groups.id
|
||||
@@ -796,7 +800,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
"path",
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
CAST(array[] AS text[]) AS actions,
|
||||
'direct_group' AS access_type,
|
||||
id AS access_provider_id,
|
||||
role_id AS access_provider_role_id,
|
||||
@@ -819,7 +823,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
"path",
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
CAST(array[] AS text[]) AS actions,
|
||||
'indirect_group' AS access_type,
|
||||
access_provider_id,
|
||||
access_provider_role_id,
|
||||
@@ -856,7 +860,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
JOIN
|
||||
clients c ON c.parent_group_id = g.id
|
||||
WHERE
|
||||
c.id NOT IN (SELECT id FROM direct_clients)
|
||||
NOT EXISTS (SELECT 1 FROM direct_clients dc WHERE dc.id = c.id)
|
||||
UNION
|
||||
SELECT * FROM direct_clients
|
||||
),
|
||||
@@ -902,7 +906,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
g."path" AS parent_group_path,
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
CAST(array[] AS text[]) AS actions,
|
||||
'domain' AS access_type,
|
||||
d.id AS access_provider_id,
|
||||
dr.id AS access_provider_role_id,
|
||||
@@ -921,18 +925,17 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
LEFT JOIN
|
||||
groups g ON dc.parent_group_id = g.id
|
||||
WHERE
|
||||
drm.member_id = '%s' -- user_id
|
||||
AND d.id = '%s' -- domain_id
|
||||
AND dra."action" LIKE 'client_%%'
|
||||
AND NOT EXISTS ( -- Ensures that the direct and indirect clients are not in included.
|
||||
drm.member_id = :user_id
|
||||
AND d.id = :domain_id_param
|
||||
AND dra."action" LIKE 'client_%'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM groups_clients gc
|
||||
WHERE gc.id = dc.id
|
||||
)
|
||||
GROUP BY
|
||||
dc.id, d.id, dr.id, g."path"
|
||||
)
|
||||
`, userID, domainID, userID, domainID, userID, domainID, domainID, userID, domainID)
|
||||
}
|
||||
`
|
||||
|
||||
func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
|
||||
query, err := PageQuery(pm)
|
||||
@@ -1056,6 +1059,7 @@ type DBClient struct {
|
||||
ConnectionTypes pq.Int32Array `db:"connection_types,omitempty"`
|
||||
MemberID string `db:"member_id,omitempty"`
|
||||
Roles json.RawMessage `db:"roles,omitempty"`
|
||||
TotalCount uint64 `db:"total_count"`
|
||||
}
|
||||
|
||||
func ToDBClient(c clients.Client) (DBClient, error) {
|
||||
@@ -1206,6 +1210,7 @@ func ToDBClientsPage(pm clients.Page) (dbClientsPage, error) {
|
||||
RoleID: pm.RoleID,
|
||||
Actions: pm.Actions,
|
||||
AccessType: pm.AccessType,
|
||||
IDs: pq.StringArray(pm.IDs),
|
||||
CreatedFrom: pm.CreatedFrom,
|
||||
CreatedTo: pm.CreatedTo,
|
||||
}, nil
|
||||
@@ -1230,6 +1235,9 @@ type dbClientsPage struct {
|
||||
AccessType string `db:"access_type"`
|
||||
CreatedFrom time.Time `db:"created_from"`
|
||||
CreatedTo time.Time `db:"created_to"`
|
||||
IDs pq.StringArray `db:"ids"`
|
||||
UserID string `db:"user_id"`
|
||||
DomainID string `db:"domain_id_param"`
|
||||
}
|
||||
|
||||
func PageQuery(pm clients.Page) (string, error) {
|
||||
@@ -1252,7 +1260,7 @@ func PageQuery(pm clients.Page) (string, error) {
|
||||
}
|
||||
}
|
||||
if len(pm.IDs) != 0 {
|
||||
query = append(query, fmt.Sprintf("c.id IN ('%s')", strings.Join(pm.IDs, "','")))
|
||||
query = append(query, "c.id = ANY(:ids)")
|
||||
}
|
||||
|
||||
if pm.Status != clients.AllStatus {
|
||||
@@ -1432,9 +1440,8 @@ func (repo *clientRepo) RemoveConnections(ctx context.Context, conns []clients.C
|
||||
}
|
||||
}()
|
||||
|
||||
query := `DELETE FROM connections WHERE channel_id = :channel_id AND domain_id = :domain_id AND client_id = :client_id`
|
||||
|
||||
for _, conn := range conns {
|
||||
query := `DELETE FROM connections WHERE channel_id = :channel_id AND domain_id = :domain_id AND client_id = :client_id`
|
||||
if uint8(conn.Type) > 0 {
|
||||
query = query + " AND type = :type "
|
||||
}
|
||||
|
||||
@@ -99,6 +99,19 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
|
||||
`SELECT 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "clients_06",
|
||||
Up: []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_clients_domain_id_status ON clients(domain_id, status);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_clients_parent_group_id ON clients(parent_group_id);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_connections_client_id ON connections(client_id);`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP INDEX IF EXISTS idx_clients_domain_id_status;`,
|
||||
`DROP INDEX IF EXISTS idx_clients_parent_group_id;`,
|
||||
`DROP INDEX IF EXISTS idx_connections_client_id;`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+90
-76
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
@@ -252,7 +251,6 @@ func (repo domainRepo) RetrieveDomainByRoute(ctx context.Context, route string)
|
||||
|
||||
// RetrieveAllByIDs retrieves for given Domain IDs .
|
||||
func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) {
|
||||
var q string
|
||||
if len(pm.IDs) == 0 {
|
||||
return domains.DomainsPage{}, nil
|
||||
}
|
||||
@@ -263,11 +261,11 @@ func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.P
|
||||
|
||||
baseQ := `SELECT d.id as id, d.name as name, d.tags as tags, d.route as route, d.metadata as metadata,
|
||||
d.created_at as created_at, d.updated_at as updated_at, d.updated_by as updated_by,
|
||||
d.created_by as created_by, d.status as status FROM domains d`
|
||||
d.created_by as created_by, d.status as status, COUNT(*) OVER() AS total_count FROM domains d`
|
||||
|
||||
squery := applyOrdering(query, pm)
|
||||
|
||||
q = fmt.Sprintf("%s %s LIMIT %d OFFSET %d;", baseQ, squery, pm.Limit, pm.Offset)
|
||||
q := fmt.Sprintf("%s %s LIMIT %d OFFSET %d;", baseQ, squery, pm.Limit, pm.Offset)
|
||||
|
||||
dbPage, err := toDBDomainsPage(pm)
|
||||
if err != nil {
|
||||
@@ -280,19 +278,30 @@ func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.P
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
doms, err := repo.processRows(rows)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
var total uint64
|
||||
var doms []domains.Domain
|
||||
for rows.Next() {
|
||||
dbd := dbDomain{}
|
||||
if err := rows.StructScan(&dbd); err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
total = dbd.TotalCount
|
||||
d, err := toDomain(dbd)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
doms = append(doms, d)
|
||||
}
|
||||
|
||||
cq := "SELECT COUNT(*) FROM domains d"
|
||||
if query != "" {
|
||||
cq = fmt.Sprintf(" %s %s", cq, query)
|
||||
}
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
if len(doms) == 0 {
|
||||
cq := "SELECT COUNT(*) FROM domains d"
|
||||
if query != "" {
|
||||
cq = fmt.Sprintf(" %s %s", cq, query)
|
||||
}
|
||||
total, err = postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
}
|
||||
|
||||
return domains.DomainsPage{
|
||||
@@ -321,14 +330,15 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
|
||||
d.updated_at as updated_at,
|
||||
d.updated_by as updated_by,
|
||||
d.created_by as created_by,
|
||||
d.status as status
|
||||
d.status as status,
|
||||
COUNT(*) OVER() AS total_count
|
||||
FROM
|
||||
domains as d
|
||||
%s
|
||||
LIMIT :limit OFFSET :offset`
|
||||
|
||||
if pm.UserID != "" {
|
||||
q = repo.userDomainsBaseQuery() +
|
||||
q = userDomainsBaseQuery +
|
||||
`
|
||||
SELECT
|
||||
d.id as id,
|
||||
@@ -343,7 +353,8 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
|
||||
d.created_at as created_at,
|
||||
d.updated_at as updated_at,
|
||||
d.updated_by as updated_by,
|
||||
d.created_by as created_by
|
||||
d.created_by as created_by,
|
||||
COUNT(*) OVER() AS total_count
|
||||
FROM
|
||||
domains d
|
||||
%s
|
||||
@@ -358,35 +369,55 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
|
||||
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
var doms []domains.Domain
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
doms, err = repo.processRows(rows)
|
||||
if pm.OnlyTotal {
|
||||
cq := `SELECT COUNT(*) FROM domains as d %s`
|
||||
if pm.UserID != "" {
|
||||
cq = userDomainsBaseQuery + cq
|
||||
}
|
||||
if query != "" {
|
||||
cq = fmt.Sprintf(cq, query)
|
||||
}
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
return domains.DomainsPage{Total: total, Offset: pm.Offset, Limit: pm.Limit}, nil
|
||||
}
|
||||
|
||||
cq := `SELECT COUNT(*)
|
||||
FROM domains as d %s`
|
||||
|
||||
if pm.UserID != "" {
|
||||
cq = repo.userDomainsBaseQuery() + cq
|
||||
}
|
||||
|
||||
if query != "" {
|
||||
cq = fmt.Sprintf(cq, query)
|
||||
}
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var total uint64
|
||||
var doms []domains.Domain
|
||||
for rows.Next() {
|
||||
dbd := dbDomain{}
|
||||
if err := rows.StructScan(&dbd); err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
total = dbd.TotalCount
|
||||
d, err := toDomain(dbd)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
doms = append(doms, d)
|
||||
}
|
||||
|
||||
if len(doms) == 0 {
|
||||
cq := `SELECT COUNT(*) FROM domains as d %s`
|
||||
if pm.UserID != "" {
|
||||
cq = userDomainsBaseQuery + cq
|
||||
}
|
||||
if query != "" {
|
||||
cq = fmt.Sprintf(cq, query)
|
||||
}
|
||||
total, err = postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
}
|
||||
|
||||
return domains.DomainsPage{
|
||||
Total: total,
|
||||
@@ -479,8 +510,7 @@ func (repo domainRepo) DeleteDomain(ctx context.Context, id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (repo domainRepo) userDomainsBaseQuery() string {
|
||||
return `
|
||||
const userDomainsBaseQuery = `
|
||||
with domains AS (
|
||||
SELECT
|
||||
d.id as id,
|
||||
@@ -511,23 +541,6 @@ func (repo domainRepo) userDomainsBaseQuery() string {
|
||||
GROUP BY
|
||||
dr.entity_id, drm.member_id, dr.id, dr."name", d.id
|
||||
)`
|
||||
}
|
||||
|
||||
func (repo domainRepo) processRows(rows *sqlx.Rows) ([]domains.Domain, error) {
|
||||
var items []domains.Domain
|
||||
for rows.Next() {
|
||||
dbd := dbDomain{}
|
||||
if err := rows.StructScan(&dbd); err != nil {
|
||||
return items, err
|
||||
}
|
||||
d, err := toDomain(dbd)
|
||||
if err != nil {
|
||||
return items, err
|
||||
}
|
||||
items = append(items, d)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func applyOrdering(emq string, pm domains.Page) string {
|
||||
col := "COALESCE(d.updated_at, d.created_at)"
|
||||
@@ -550,21 +563,22 @@ func applyOrdering(emq string, pm domains.Page) string {
|
||||
}
|
||||
|
||||
type dbDomain struct {
|
||||
ID string `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Metadata []byte `db:"metadata,omitempty"`
|
||||
Tags pgtype.TextArray `db:"tags,omitempty"`
|
||||
Route *string `db:"route,omitempty"`
|
||||
Status domains.Status `db:"status"`
|
||||
RoleID string `db:"role_id"`
|
||||
RoleName string `db:"role_name"`
|
||||
Actions pq.StringArray `db:"actions"`
|
||||
CreatedBy string `db:"created_by"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
UpdatedBy *string `db:"updated_by,omitempty"`
|
||||
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
|
||||
MemberID string `db:"member_id,omitempty"`
|
||||
Roles json.RawMessage `db:"roles,omitempty"`
|
||||
ID string `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Metadata []byte `db:"metadata,omitempty"`
|
||||
Tags pgtype.TextArray `db:"tags,omitempty"`
|
||||
Route *string `db:"route,omitempty"`
|
||||
Status domains.Status `db:"status"`
|
||||
RoleID string `db:"role_id"`
|
||||
RoleName string `db:"role_name"`
|
||||
Actions pq.StringArray `db:"actions"`
|
||||
CreatedBy string `db:"created_by"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
UpdatedBy *string `db:"updated_by,omitempty"`
|
||||
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
|
||||
MemberID string `db:"member_id,omitempty"`
|
||||
Roles json.RawMessage `db:"roles,omitempty"`
|
||||
TotalCount uint64 `db:"total_count"`
|
||||
}
|
||||
|
||||
func toDBDomain(d domains.Domain) (dbDomain, error) {
|
||||
@@ -670,7 +684,7 @@ type dbDomainsPage struct {
|
||||
RoleName string `db:"role_name"`
|
||||
Actions pq.StringArray `db:"actions"`
|
||||
ID string `db:"id"`
|
||||
IDs []string `db:"ids"`
|
||||
IDs pq.StringArray `db:"ids"`
|
||||
Metadata []byte `db:"metadata"`
|
||||
Tags pgtype.TextArray `db:"tags"`
|
||||
Status domains.Status `db:"status"`
|
||||
@@ -699,7 +713,7 @@ func toDBDomainsPage(pm domains.Page) (dbDomainsPage, error) {
|
||||
RoleName: pm.RoleName,
|
||||
Actions: pm.Actions,
|
||||
ID: pm.ID,
|
||||
IDs: pm.IDs,
|
||||
IDs: pq.StringArray(pm.IDs),
|
||||
Metadata: data,
|
||||
Tags: tags,
|
||||
Status: pm.Status,
|
||||
@@ -718,7 +732,7 @@ func buildPageQuery(pm domains.Page) (string, error) {
|
||||
}
|
||||
|
||||
if len(pm.IDs) != 0 {
|
||||
query = append(query, fmt.Sprintf("d.id IN ('%s')", strings.Join(pm.IDs, "','")))
|
||||
query = append(query, "d.id = ANY(:ids)")
|
||||
}
|
||||
|
||||
if (pm.Status >= domains.EnabledStatus) && (pm.Status < domains.AllStatus) {
|
||||
|
||||
@@ -160,6 +160,17 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
|
||||
END $$;`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "domain_6",
|
||||
Up: []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_invitations_invited_by ON invitations(invited_by);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_invitations_role_id ON invitations(role_id);`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP INDEX IF EXISTS idx_invitations_invited_by;`,
|
||||
`DROP INDEX IF EXISTS idx_invitations_role_id;`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "domain_7",
|
||||
Up: []string{
|
||||
|
||||
+188
-125
@@ -60,7 +60,7 @@ func New(db postgres.Database) groups.Repository {
|
||||
}
|
||||
|
||||
func (repo groupRepository) Save(ctx context.Context, g groups.Group) (groups.Group, error) {
|
||||
q, err := repo.getInsertQuery(ctx, g)
|
||||
q, computedPath, err := repo.getInsertQuery(ctx, g)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
@@ -68,6 +68,9 @@ func (repo groupRepository) Save(ctx context.Context, g groups.Group) (groups.Gr
|
||||
if err != nil {
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
if computedPath != "" {
|
||||
dbg.Path = computedPath
|
||||
}
|
||||
|
||||
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
|
||||
if err != nil {
|
||||
@@ -256,7 +259,7 @@ func (repo groupRepository) RetrieveByIDWithRoles(ctx context.Context, id, membe
|
||||
dr.id AS role_id,
|
||||
dr.name AS role_name,
|
||||
jsonb_agg(DISTINCT all_actions.action) AS actions,
|
||||
''::::ltree access_provider_path,
|
||||
CAST('' AS ltree) access_provider_path,
|
||||
'domain' AS access_type,
|
||||
dr.entity_id AS access_provider_id
|
||||
FROM
|
||||
@@ -362,9 +365,9 @@ func (repo groupRepository) RetrieveByIDWithRoles(ctx context.Context, id, membe
|
||||
}
|
||||
|
||||
func (repo groupRepository) RetrieveByIDAndUser(ctx context.Context, domainID, userID, groupID string) (groups.Group, error) {
|
||||
baseQuery := repo.userGroupsBaseQuery(domainID, userID)
|
||||
baseQuery := userGroupsBaseQuery
|
||||
|
||||
dbg := dbGroup{ID: groupID}
|
||||
dbg := dbGroup{ID: groupID, UserID: userID, DomainIDParam: domainID}
|
||||
q := fmt.Sprintf(`%s
|
||||
SELECT
|
||||
g.id,
|
||||
@@ -438,39 +441,54 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
|
||||
orderClause = fmt.Sprintf("ORDER BY %s %s, g.id %s", orderBy, dir, dir)
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`SELECT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.description,
|
||||
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s %s LIMIT :limit OFFSET :offset;`, query, orderClause)
|
||||
|
||||
dbPageMeta, err := toDBGroupPageMeta(pm)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
var items []groups.Group
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items, err = repo.processRows(rows)
|
||||
if pm.OnlyTotal {
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM groups g %s;`, query)
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
page := groups.Page{PageMeta: pm}
|
||||
page.Total = total
|
||||
return page, nil
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(` SELECT COUNT(*) AS total_count
|
||||
FROM (
|
||||
SELECT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
|
||||
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s
|
||||
) AS subquery;
|
||||
`, query)
|
||||
q := fmt.Sprintf(`SELECT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.description,
|
||||
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status,
|
||||
COUNT(*) OVER() AS total_count FROM groups g %s %s LIMIT :limit OFFSET :offset;`, query, orderClause)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var total uint64
|
||||
var items []groups.Group
|
||||
for rows.Next() {
|
||||
dbg := dbGroup{}
|
||||
if err := rows.StructScan(&dbg); err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
total = dbg.TotalCount
|
||||
g, err := toGroup(dbg)
|
||||
if err != nil {
|
||||
return groups.Page{}, err
|
||||
}
|
||||
items = append(items, g)
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM groups g %s;`, query)
|
||||
total, err = postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
}
|
||||
|
||||
page := groups.Page{PageMeta: pm}
|
||||
page.Total = total
|
||||
@@ -479,40 +497,49 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
|
||||
}
|
||||
|
||||
func (repo groupRepository) RetrieveByIDs(ctx context.Context, pm groups.PageMeta, ids ...string) (groups.Page, error) {
|
||||
var q string
|
||||
if (len(ids) == 0) && (pm.DomainID == "") {
|
||||
return groups.Page{PageMeta: groups.PageMeta{Offset: pm.Offset, Limit: pm.Limit}}, nil
|
||||
}
|
||||
query := buildQuery(pm, ids...)
|
||||
|
||||
q = fmt.Sprintf(`SELECT DISTINCT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
|
||||
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s ORDER BY g.created_at LIMIT :limit OFFSET :offset;`, query)
|
||||
q := fmt.Sprintf(`SELECT DISTINCT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
|
||||
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status,
|
||||
COUNT(*) OVER() AS total_count FROM groups g %s ORDER BY g.created_at LIMIT :limit OFFSET :offset;`, query)
|
||||
|
||||
dbPageMeta, err := toDBGroupPageMeta(pm)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
dbPageMeta.IDs = pq.StringArray(ids)
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items, err := repo.processRows(rows)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
var total uint64
|
||||
var items []groups.Group
|
||||
for rows.Next() {
|
||||
dbg := dbGroup{}
|
||||
if err := rows.StructScan(&dbg); err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
total = dbg.TotalCount
|
||||
g, err := toGroup(dbg)
|
||||
if err != nil {
|
||||
return groups.Page{}, err
|
||||
}
|
||||
items = append(items, g)
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(` SELECT COUNT(*) AS total_count
|
||||
FROM (
|
||||
SELECT DISTINCT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
|
||||
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s
|
||||
) AS subquery;
|
||||
`, query)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
if len(items) == 0 {
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM (
|
||||
SELECT DISTINCT g.id FROM groups g %s
|
||||
) AS subquery;`, query)
|
||||
total, err = postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
}
|
||||
|
||||
page := groups.Page{PageMeta: pm}
|
||||
@@ -530,7 +557,7 @@ func (repo groupRepository) RetrieveHierarchy(ctx context.Context, domainID, use
|
||||
dirQuery = "g.path <@ (SELECT path FROM groups WHERE id = :id)"
|
||||
}
|
||||
|
||||
baseQuery := repo.userGroupsBaseQuery(domainID, userID)
|
||||
baseQuery := userGroupsBaseQuery
|
||||
query := fmt.Sprintf(`%s,
|
||||
target_hierarchy AS (
|
||||
SELECT
|
||||
@@ -568,8 +595,10 @@ func (repo groupRepository) RetrieveHierarchy(ctx context.Context, domainID, use
|
||||
`, baseQuery, dirQuery)
|
||||
|
||||
parameters := map[string]any{
|
||||
"id": groupID,
|
||||
"level": hm.Level,
|
||||
"id": groupID,
|
||||
"level": hm.Level,
|
||||
"user_id": userID,
|
||||
"domain_id_param": domainID,
|
||||
}
|
||||
|
||||
rows, err := repo.db.NamedQueryContext(ctx, query, parameters)
|
||||
@@ -798,7 +827,7 @@ func (repo groupRepository) RetrieveAllParentGroups(ctx context.Context, domainI
|
||||
|
||||
query := buildQuery(pm)
|
||||
|
||||
levelCondition := fmt.Sprintf("g.path @> '%s' ", cGroup.Path)
|
||||
levelCondition := "g.path @> CAST(:path AS ltree) "
|
||||
|
||||
switch {
|
||||
case query == "":
|
||||
@@ -807,6 +836,7 @@ func (repo groupRepository) RetrieveAllParentGroups(ctx context.Context, domainI
|
||||
query = query + " AND " + levelCondition
|
||||
}
|
||||
|
||||
pm.Path = cGroup.Path
|
||||
return repo.retrieveGroups(ctx, domainID, userID, query, pm)
|
||||
}
|
||||
|
||||
@@ -822,19 +852,19 @@ func (repo groupRepository) RetrieveChildrenGroups(ctx context.Context, domainID
|
||||
switch {
|
||||
// Retrieve all children groups from parent group level
|
||||
case startLevel == 0 && endLevel < 0:
|
||||
levelCondition = fmt.Sprintf(" path ~ '%s.*'::::lquery ", pGroup.Path)
|
||||
levelCondition = " path ~ CAST(:path || '.*' AS lquery) "
|
||||
|
||||
// Retrieve specific level of children groups from parent group level
|
||||
case (startLevel > 0) && (startLevel == endLevel || endLevel == 0):
|
||||
levelCondition = fmt.Sprintf(" path ~ '%s.*{%d}'::::lquery ", pGroup.Path, startLevel)
|
||||
levelCondition = fmt.Sprintf(" path ~ CAST(:path || '.*{%d}' AS lquery) ", startLevel)
|
||||
|
||||
// Retrieve all children groups from specific level from parent group level
|
||||
case startLevel > 0 && endLevel < 0:
|
||||
levelCondition = fmt.Sprintf(" path ~ '%s.*{%d,}'::::lquery ", pGroup.Path, startLevel)
|
||||
levelCondition = fmt.Sprintf(" path ~ CAST(:path || '.*{%d,}' AS lquery) ", startLevel)
|
||||
|
||||
// Retrieve children groups between specific level from parent group level
|
||||
case startLevel > 0 && endLevel > 0 && startLevel < endLevel:
|
||||
levelCondition = fmt.Sprintf(" path ~ '%s.*{%d,%d}'::::lquery ", pGroup.Path, startLevel, endLevel)
|
||||
levelCondition = fmt.Sprintf(" path ~ CAST(:path || '.*{%d,%d}' AS lquery) ", startLevel, endLevel)
|
||||
default:
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrViewEntity, fmt.Errorf("invalid level range: start level: %d end level: %d", startLevel, endLevel))
|
||||
}
|
||||
@@ -846,6 +876,7 @@ func (repo groupRepository) RetrieveChildrenGroups(ctx context.Context, domainID
|
||||
query = query + " AND " + levelCondition
|
||||
}
|
||||
|
||||
pm.Path = pGroup.Path
|
||||
return repo.retrieveGroups(ctx, domainID, userID, query, pm)
|
||||
}
|
||||
|
||||
@@ -868,7 +899,7 @@ func (repo groupRepository) RetrieveUserGroups(ctx context.Context, domainID, us
|
||||
}
|
||||
|
||||
func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID, query string, pm groups.PageMeta) (groups.Page, error) {
|
||||
baseQuery := repo.userGroupsBaseQuery(domainID, userID)
|
||||
baseQuery := userGroupsBaseQuery
|
||||
|
||||
orderClause := ""
|
||||
var orderBy string
|
||||
@@ -889,6 +920,30 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
|
||||
orderClause = fmt.Sprintf("ORDER BY %s %s, g.id %s", orderBy, dir, dir)
|
||||
}
|
||||
|
||||
dbPageMeta, err := toDBGroupPageMeta(pm)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
dbPageMeta.UserID = userID
|
||||
dbPageMeta.DomainIDParam = domainID
|
||||
|
||||
if pm.OnlyTotal {
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM final_groups g
|
||||
%s;
|
||||
`, baseQuery, query)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
page := groups.Page{PageMeta: pm}
|
||||
page.Total = total
|
||||
return page, nil
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`%s
|
||||
SELECT
|
||||
g.id,
|
||||
@@ -910,45 +965,49 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
|
||||
g.access_provider_id,
|
||||
g.access_provider_role_id,
|
||||
g.access_provider_role_name,
|
||||
g.access_provider_role_actions
|
||||
g.access_provider_role_actions,
|
||||
COUNT(*) OVER() AS total_count
|
||||
FROM final_groups g
|
||||
%s
|
||||
%s
|
||||
LIMIT :limit OFFSET :offset;`,
|
||||
baseQuery, query, orderClause)
|
||||
|
||||
dbPageMeta, err := toDBGroupPageMeta(pm)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
var items []groups.Group
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items, err = repo.processRows(rows)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM (
|
||||
SELECT g.id
|
||||
FROM final_groups g
|
||||
%s
|
||||
) AS subquery;`,
|
||||
baseQuery, query)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var total uint64
|
||||
var items []groups.Group
|
||||
for rows.Next() {
|
||||
dbg := dbGroup{}
|
||||
if err := rows.StructScan(&dbg); err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
total = dbg.TotalCount
|
||||
|
||||
group, err := toGroup(dbg)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
items = append(items, group)
|
||||
}
|
||||
|
||||
if len(items) == 0 {
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM final_groups g
|
||||
%s;
|
||||
`, baseQuery, query)
|
||||
|
||||
total, err = postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
}
|
||||
|
||||
page := groups.Page{PageMeta: pm}
|
||||
page.Total = total
|
||||
@@ -956,8 +1015,7 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (repo groupRepository) userGroupsBaseQuery(domainID, userID string) string {
|
||||
return fmt.Sprintf(`
|
||||
const userGroupsBaseQuery = `
|
||||
WITH direct_groups AS (
|
||||
SELECT
|
||||
g.*,
|
||||
@@ -975,8 +1033,8 @@ JOIN
|
||||
JOIN
|
||||
"groups" g ON g.id = gr.entity_id
|
||||
WHERE
|
||||
grm.member_id = '%s'
|
||||
AND g.domain_id = '%s'
|
||||
grm.member_id = :user_id
|
||||
AND g.domain_id = :domain_id_param
|
||||
GROUP BY
|
||||
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
|
||||
),
|
||||
@@ -997,12 +1055,12 @@ direct_groups_with_subgroup AS (
|
||||
JOIN
|
||||
"groups" g ON g.id = gr.entity_id
|
||||
WHERE
|
||||
grm.member_id = '%s'
|
||||
AND g.domain_id = '%s'
|
||||
grm.member_id = :user_id
|
||||
AND g.domain_id = :domain_id_param
|
||||
GROUP BY
|
||||
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
|
||||
HAVING
|
||||
bool_or(gra."action" LIKE 'subgroup_%%')
|
||||
bool_or(gra."action" LIKE 'subgroup_%')
|
||||
),
|
||||
direct_leaf_groups_with_subgroup AS (
|
||||
SELECT dgws.*
|
||||
@@ -1026,11 +1084,10 @@ indirect_child_groups AS (
|
||||
FROM
|
||||
direct_leaf_groups_with_subgroup dlgws
|
||||
JOIN
|
||||
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path -- Finds all children of entity_id based on ltree path
|
||||
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path
|
||||
WHERE
|
||||
indirect_child_groups.domain_id = '%s'
|
||||
AND
|
||||
NOT EXISTS ( -- Ensures that the indirect_child_groups.id is not already in the direct_groups_with_subgroup table
|
||||
indirect_child_groups.domain_id = :domain_id_param
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM direct_groups_with_subgroup dgws
|
||||
WHERE dgws.id = indirect_child_groups.id
|
||||
@@ -1057,7 +1114,7 @@ direct_indirect_groups as (
|
||||
'' AS access_provider_id,
|
||||
'' AS access_provider_role_id,
|
||||
'' AS access_provider_role_name,
|
||||
array[]::::text[] AS access_provider_role_actions
|
||||
CAST(array[] AS text[]) AS access_provider_role_actions
|
||||
FROM
|
||||
direct_groups
|
||||
UNION
|
||||
@@ -1076,7 +1133,7 @@ direct_indirect_groups as (
|
||||
"path",
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
CAST(array[] AS text[]) AS actions,
|
||||
'indirect' AS access_type,
|
||||
access_provider_id,
|
||||
access_provider_role_id,
|
||||
@@ -1125,7 +1182,7 @@ final_groups AS (
|
||||
dg."path",
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
CAST(array[] AS text[]) AS actions,
|
||||
'domain' AS access_type,
|
||||
d.id AS access_provider_id,
|
||||
dr.id AS access_provider_role_id,
|
||||
@@ -1142,24 +1199,23 @@ final_groups AS (
|
||||
JOIN
|
||||
"groups" dg ON dg.domain_id = d.id
|
||||
WHERE
|
||||
drm.member_id = '%s' -- user_id
|
||||
AND d.id = '%s' -- domain_id
|
||||
AND dra."action" LIKE 'group_%%'
|
||||
AND NOT EXISTS ( -- Ensures that the direct and indirect groups are not in included.
|
||||
drm.member_id = :user_id
|
||||
AND d.id = :domain_id_param
|
||||
AND dra."action" LIKE 'group_%'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM direct_indirect_groups dig
|
||||
WHERE dig.id = dg.id
|
||||
)
|
||||
GROUP BY
|
||||
dg.id, d.id, dr.id
|
||||
)
|
||||
`, userID, domainID, userID, domainID, domainID, userID, domainID)
|
||||
}
|
||||
`
|
||||
|
||||
func buildQuery(gm groups.PageMeta, ids ...string) string {
|
||||
queries := []string{}
|
||||
|
||||
if len(ids) > 0 {
|
||||
queries = append(queries, fmt.Sprintf(" id in ('%s') ", strings.Join(ids, "', '")))
|
||||
queries = append(queries, "id = ANY(:ids)")
|
||||
}
|
||||
if gm.Name != "" {
|
||||
queries = append(queries, "g.name ILIKE '%' || :name || '%'")
|
||||
@@ -1233,6 +1289,9 @@ type dbGroup struct {
|
||||
AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions"`
|
||||
MemberID string `db:"member_id,omitempty"`
|
||||
Roles json.RawMessage `db:"roles,omitempty"`
|
||||
TotalCount uint64 `db:"total_count"`
|
||||
UserID string `db:"user_id,omitempty"`
|
||||
DomainIDParam string `db:"domain_id_param,omitempty"`
|
||||
}
|
||||
|
||||
func toDBGroup(g groups.Group) (dbGroup, error) {
|
||||
@@ -1360,31 +1419,35 @@ func toDBGroupPageMeta(pm groups.PageMeta) (dbGroupPageMeta, error) {
|
||||
RoleID: pm.RoleID,
|
||||
Actions: pm.Actions,
|
||||
AccessType: pm.AccessType,
|
||||
Path: pm.Path,
|
||||
CreatedFrom: pm.CreatedFrom,
|
||||
CreatedTo: pm.CreatedTo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type dbGroupPageMeta struct {
|
||||
ID string `db:"id"`
|
||||
Name string `db:"name"`
|
||||
ParentID string `db:"parent_id"`
|
||||
DomainID string `db:"domain_id"`
|
||||
Metadata []byte `db:"metadata"`
|
||||
Path string `db:"path"`
|
||||
Level uint64 `db:"level"`
|
||||
Total uint64 `db:"total"`
|
||||
Limit uint64 `db:"limit"`
|
||||
Offset uint64 `db:"offset"`
|
||||
Subject string `db:"subject"`
|
||||
RoleName string `db:"role_name"`
|
||||
RoleID string `db:"role_id"`
|
||||
Actions pq.StringArray `db:"actions"`
|
||||
AccessType string `db:"access_type"`
|
||||
Status groups.Status `db:"status"`
|
||||
Tags pgtype.TextArray `db:"tags"`
|
||||
CreatedFrom time.Time `db:"created_from"`
|
||||
CreatedTo time.Time `db:"created_to"`
|
||||
ID string `db:"id"`
|
||||
Name string `db:"name"`
|
||||
ParentID string `db:"parent_id"`
|
||||
DomainID string `db:"domain_id"`
|
||||
Metadata []byte `db:"metadata"`
|
||||
Path string `db:"path"`
|
||||
Level uint64 `db:"level"`
|
||||
Total uint64 `db:"total"`
|
||||
Limit uint64 `db:"limit"`
|
||||
Offset uint64 `db:"offset"`
|
||||
Subject string `db:"subject"`
|
||||
RoleName string `db:"role_name"`
|
||||
RoleID string `db:"role_id"`
|
||||
Actions pq.StringArray `db:"actions"`
|
||||
AccessType string `db:"access_type"`
|
||||
Status groups.Status `db:"status"`
|
||||
Tags pgtype.TextArray `db:"tags"`
|
||||
IDs pq.StringArray `db:"ids"`
|
||||
CreatedFrom time.Time `db:"created_from"`
|
||||
CreatedTo time.Time `db:"created_to"`
|
||||
UserID string `db:"user_id"`
|
||||
DomainIDParam string `db:"domain_id_param"`
|
||||
}
|
||||
|
||||
func (repo groupRepository) processRows(rows *sqlx.Rows) ([]groups.Group, error) {
|
||||
@@ -1403,23 +1466,23 @@ func (repo groupRepository) processRows(rows *sqlx.Rows) ([]groups.Group, error)
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (repo groupRepository) getInsertQuery(c context.Context, g groups.Group) (string, error) {
|
||||
func (repo groupRepository) getInsertQuery(c context.Context, g groups.Group) (string, string, error) {
|
||||
switch {
|
||||
case g.Parent != "":
|
||||
parent, err := repo.RetrieveByID(c, g.Parent)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", "", err
|
||||
}
|
||||
path := parent.Path + "." + g.ID
|
||||
if len(strings.Split(path, ".")) > groups.MaxPathLength {
|
||||
return "", fmt.Errorf("reached max nested depth")
|
||||
return "", "", fmt.Errorf("reached max nested depth")
|
||||
}
|
||||
return fmt.Sprintf(`INSERT INTO groups (name, description, tags, id, domain_id, parent_id, metadata, created_at, status, path)
|
||||
VALUES (:name, :description, :tags, :id, :domain_id, :parent_id, :metadata, :created_at, :status, '%s')
|
||||
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, path), nil
|
||||
return `INSERT INTO groups (name, description, tags, id, domain_id, parent_id, metadata, created_at, status, path)
|
||||
VALUES (:name, :description, :tags, :id, :domain_id, :parent_id, :metadata, :created_at, :status, CAST(:path AS ltree))
|
||||
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, path, nil
|
||||
default:
|
||||
return `INSERT INTO groups (name, description, tags, id, domain_id, metadata, created_at, status, path)
|
||||
VALUES (:name, :description, :tags, :id, :domain_id, :metadata, :created_at, :status, :id)
|
||||
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, nil
|
||||
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, "", nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,6 +95,15 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
|
||||
`SELECT 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "groups_07",
|
||||
Up: []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_groups_domain_id_status ON groups(domain_id, status);`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP INDEX IF EXISTS idx_groups_domain_id_status;`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,15 @@ func Migration() *migrate.MemoryMigrationSource {
|
||||
`ALTER TABLE clients_telemetry ALTER COLUMN last_seen TYPE TIMESTAMP;`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "journal_03",
|
||||
Up: []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_journal_domain ON journal(domain);`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP INDEX IF EXISTS idx_journal_domain;`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
+12
-10
@@ -71,7 +71,7 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
|
||||
if page.Direction == "" {
|
||||
page.Direction = "ASC"
|
||||
}
|
||||
q := fmt.Sprintf("SELECT %s FROM journal %s ORDER BY occurred_at %s LIMIT :limit OFFSET :offset;", sq, query, page.Direction)
|
||||
q := fmt.Sprintf("SELECT %s, COUNT(*) OVER() AS total_count FROM journal %s ORDER BY occurred_at %s LIMIT :limit OFFSET :offset;", sq, query, page.Direction)
|
||||
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, page)
|
||||
if err != nil {
|
||||
@@ -79,12 +79,14 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var total uint64
|
||||
var items []journal.Journal
|
||||
for rows.Next() {
|
||||
var item dbJournal
|
||||
if err = rows.StructScan(&item); err != nil {
|
||||
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
total = item.TotalCount
|
||||
j, err := toJournal(item)
|
||||
if err != nil {
|
||||
return journal.JournalsPage{}, err
|
||||
@@ -92,21 +94,20 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
|
||||
items = append(items, j)
|
||||
}
|
||||
|
||||
tq := fmt.Sprintf(`SELECT COUNT(*) FROM journal %s;`, query)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, tq, page)
|
||||
if err != nil {
|
||||
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
if len(items) == 0 {
|
||||
tq := fmt.Sprintf(`SELECT COUNT(*) FROM journal %s;`, query)
|
||||
total, err = postgres.Total(ctx, repo.db, tq, page)
|
||||
if err != nil {
|
||||
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
journalsPage := journal.JournalsPage{
|
||||
return journal.JournalsPage{
|
||||
Total: total,
|
||||
Offset: page.Offset,
|
||||
Limit: page.Limit,
|
||||
Journals: items,
|
||||
}
|
||||
|
||||
return journalsPage, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func pageQuery(pm journal.Page) string {
|
||||
@@ -139,6 +140,7 @@ type dbJournal struct {
|
||||
OccurredAt time.Time `db:"occurred_at"`
|
||||
Attributes []byte `db:"attributes"`
|
||||
Metadata []byte `db:"metadata"`
|
||||
TotalCount uint64 `db:"total_count"`
|
||||
}
|
||||
|
||||
func toDBJournal(j journal.Journal) (dbJournal, error) {
|
||||
|
||||
@@ -66,6 +66,17 @@ func Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName string)
|
||||
fmt.Sprintf(`ALTER TABLE %s_roles ALTER COLUMN updated_at TYPE TIMESTAMP;`, rolesTableNamePrefix),
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: fmt.Sprintf("%s_roles_3", rolesTableNamePrefix),
|
||||
Up: []string{
|
||||
fmt.Sprintf(`CREATE INDEX IF NOT EXISTS idx_%s_role_members_member_id ON %s_role_members(member_id);`, rolesTableNamePrefix, rolesTableNamePrefix),
|
||||
fmt.Sprintf(`CREATE INDEX IF NOT EXISTS idx_%s_role_actions_action ON %s_role_actions(action text_pattern_ops);`, rolesTableNamePrefix, rolesTableNamePrefix),
|
||||
},
|
||||
Down: []string{
|
||||
fmt.Sprintf(`DROP INDEX IF EXISTS idx_%s_role_members_member_id;`, rolesTableNamePrefix),
|
||||
fmt.Sprintf(`DROP INDEX IF EXISTS idx_%s_role_actions_action;`, rolesTableNamePrefix),
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -694,7 +694,7 @@ func (repo *Repository) RoleRemoveAllMembers(ctx context.Context, role roles.Rol
|
||||
|
||||
dbrcap := dbRoleAction{RoleID: role.ID}
|
||||
|
||||
if _, err := repo.db.NamedExecContext(ctx, q, dbrcap); err != nil {
|
||||
if _, err := tx.NamedExec(q, dbrcap); err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
|
||||
@@ -1113,7 +1113,7 @@ ungrouped_members AS (
|
||||
ARRAY_AGG(DISTINCT cra."action") AS actions,
|
||||
'direct' AS access_type,
|
||||
'' AS access_provider_id,
|
||||
''::::LTREE AS access_provider_path
|
||||
CAST('' AS LTREE) AS access_provider_path
|
||||
FROM
|
||||
client_group cg
|
||||
JOIN
|
||||
@@ -1180,7 +1180,7 @@ ungrouped_members AS (
|
||||
ARRAY_AGG(DISTINCT agg_dra."action") AS actions,
|
||||
'domain' AS access_type,
|
||||
d.id AS access_provider_id,
|
||||
''::::LTREE AS access_provider_path
|
||||
CAST('' AS LTREE) AS access_provider_path
|
||||
FROM
|
||||
client_group cg
|
||||
JOIN
|
||||
@@ -1251,7 +1251,7 @@ ungrouped_members AS (
|
||||
ARRAY_AGG(DISTINCT cra."action") AS actions,
|
||||
'direct' AS access_type,
|
||||
'' AS access_provider_id,
|
||||
''::::LTREE AS access_provider_path
|
||||
CAST('' AS LTREE) AS access_provider_path
|
||||
FROM
|
||||
channel_group cg
|
||||
JOIN
|
||||
@@ -1318,7 +1318,7 @@ ungrouped_members AS (
|
||||
ARRAY_AGG(DISTINCT agg_dra."action") AS actions,
|
||||
'domain' AS access_type,
|
||||
d.id AS access_provider_id,
|
||||
''::::LTREE AS access_provider_path
|
||||
CAST('' AS LTREE) AS access_provider_path
|
||||
FROM
|
||||
channel_group cg
|
||||
JOIN
|
||||
|
||||
+76
-69
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/absmach/supermq/pkg/postgres"
|
||||
"github.com/absmach/supermq/users"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type userRepo struct {
|
||||
@@ -122,57 +123,64 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
|
||||
}
|
||||
|
||||
squery := applyOrdering(query, pm)
|
||||
|
||||
q := fmt.Sprintf(`SELECT u.id, u.tags, u.email, u.metadata, u.status, u.role, u.first_name, u.last_name, u.username,
|
||||
u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by, u.verified_at
|
||||
FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
|
||||
|
||||
dbPage, err := ToDBUsersPage(pm)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
|
||||
}
|
||||
|
||||
var items []users.User
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if pm.OnlyTotal {
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
|
||||
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrRetrieveAllUsers, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
dbu := DBUser{}
|
||||
if err := rows.StructScan(&dbu); err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return users.UsersPage{
|
||||
Page: users.Page{Total: total, Offset: pm.Offset, Limit: pm.Limit},
|
||||
}, nil
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
|
||||
squery := applyOrdering(query, pm)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
q := fmt.Sprintf(`SELECT u.id, u.tags, u.email, u.metadata, u.status, u.role, u.first_name, u.last_name, u.username,
|
||||
u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by, u.verified_at,
|
||||
COUNT(*) OVER() AS total_count
|
||||
FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
|
||||
|
||||
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrRetrieveAllUsers, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var total uint64
|
||||
var items []users.User
|
||||
for rows.Next() {
|
||||
dbu := DBUser{}
|
||||
if err := rows.StructScan(&dbu); err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
total = dbu.TotalCount
|
||||
|
||||
c, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
}
|
||||
|
||||
page := users.UsersPage{
|
||||
Page: users.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
if len(items) == 0 {
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
|
||||
total, err = postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
return users.UsersPage{
|
||||
Page: users.Page{Total: total, Offset: pm.Offset, Limit: pm.Limit},
|
||||
Users: items,
|
||||
}
|
||||
|
||||
return page, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (repo *userRepo) UpdateUsername(ctx context.Context, user users.User) (users.User, error) {
|
||||
@@ -317,10 +325,10 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
|
||||
}
|
||||
|
||||
tq := query
|
||||
query = applyOrdering(query, pm)
|
||||
squery := applyOrdering(query, pm)
|
||||
|
||||
q := fmt.Sprintf(`SELECT u.id, u.username, u.metadata, u.first_name, u.last_name, u.created_at, u.updated_at FROM users u %s LIMIT :limit OFFSET :offset;`, query)
|
||||
q := fmt.Sprintf(`SELECT u.id, u.username, u.metadata, u.first_name, u.last_name, u.created_at, u.updated_at,
|
||||
COUNT(*) OVER() AS total_count FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
|
||||
|
||||
dbPage, err := ToDBUsersPage(pm)
|
||||
if err != nil {
|
||||
@@ -333,12 +341,14 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var total uint64
|
||||
var items []users.User
|
||||
for rows.Next() {
|
||||
dbu := DBUser{}
|
||||
if err := rows.StructScan(&dbu); err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
total = dbu.TotalCount
|
||||
|
||||
c, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
@@ -348,23 +358,18 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
|
||||
items = append(items, c)
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, tq)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
if len(items) == 0 {
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
|
||||
total, err = postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
page := users.UsersPage{
|
||||
return users.UsersPage{
|
||||
Users: items,
|
||||
Page: users.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}
|
||||
|
||||
return page, nil
|
||||
Page: users.Page{Total: total, Offset: pm.Offset, Limit: pm.Limit},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (users.UsersPage, error) {
|
||||
@@ -380,7 +385,8 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
|
||||
squery := applyOrdering(query, pm)
|
||||
|
||||
q := fmt.Sprintf(`SELECT u.id, u.username, u.tags, u.email, u.metadata, u.status, u.role, u.first_name, u.last_name,
|
||||
u.created_at, u.updated_at, COALESCE(u.updated_by, '') AS updated_by FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
|
||||
u.created_at, u.updated_at, COALESCE(u.updated_by, '') AS updated_by,
|
||||
COUNT(*) OVER() AS total_count FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
|
||||
dbPage, err := ToDBUsersPage(pm)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
|
||||
@@ -391,12 +397,14 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var total uint64
|
||||
var items []users.User
|
||||
for rows.Next() {
|
||||
dbu := DBUser{}
|
||||
if err := rows.StructScan(&dbu); err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
total = dbu.TotalCount
|
||||
|
||||
c, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
@@ -405,23 +413,19 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
|
||||
|
||||
items = append(items, c)
|
||||
}
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
if len(items) == 0 {
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
|
||||
total, err = postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
page := users.UsersPage{
|
||||
return users.UsersPage{
|
||||
Users: items,
|
||||
Page: users.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}
|
||||
|
||||
return page, nil
|
||||
Page: users.Page{Total: total, Offset: pm.Offset, Limit: pm.Limit},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users.User, error) {
|
||||
@@ -498,6 +502,7 @@ type DBUser struct {
|
||||
Email string `db:"email,omitempty"`
|
||||
VerifiedAt sql.NullTime `db:"verified_at,omitempty"`
|
||||
AuthProvider sql.NullString `db:"auth_provider,omitempty"`
|
||||
TotalCount uint64 `db:"total_count"`
|
||||
}
|
||||
|
||||
func toDBUser(u users.User) (DBUser, error) {
|
||||
@@ -634,6 +639,7 @@ type DBUsersPage struct {
|
||||
GroupID string `db:"group_id"`
|
||||
Role users.Role `db:"role"`
|
||||
Status users.Status `db:"status"`
|
||||
IDs pq.StringArray `db:"ids"`
|
||||
CreatedFrom time.Time `db:"created_from"`
|
||||
CreatedTo time.Time `db:"created_to"`
|
||||
}
|
||||
@@ -662,6 +668,7 @@ func ToDBUsersPage(pm users.Page) (DBUsersPage, error) {
|
||||
Status: pm.Status,
|
||||
Tags: tags,
|
||||
Role: pm.Role,
|
||||
IDs: pq.StringArray(pm.IDs),
|
||||
CreatedFrom: pm.CreatedFrom,
|
||||
CreatedTo: pm.CreatedTo,
|
||||
}, nil
|
||||
@@ -699,7 +706,7 @@ func PageQuery(pm users.Page) (string, error) {
|
||||
query = append(query, "metadata @> :metadata")
|
||||
}
|
||||
if len(pm.IDs) != 0 {
|
||||
query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','")))
|
||||
query = append(query, "id = ANY(:ids)")
|
||||
}
|
||||
if pm.Status != users.AllStatus {
|
||||
query = append(query, "u.status = :status")
|
||||
|
||||
Reference in New Issue
Block a user