NOISSUE - Improve SQL queries performance and safety (#3378)
Continuous Delivery / lint-and-build (push) Has been cancelled
Deploy GitHub Pages / swagger-ui (push) Has been cancelled
CI Pipeline / Check Certs (push) Has been cancelled
CI Pipeline / Lint Proto (push) Has been cancelled
CI Pipeline / Detect Changes (push) Has been cancelled
Continuous Delivery / Build and Push Docker Images (push) Has been cancelled
CI Pipeline / lint-and-build (push) Has been cancelled
CI Pipeline / Test ${{ matrix.module }} (push) Has been cancelled
CI Pipeline / Upload Coverage (push) Has been cancelled

Signed-off-by: dusan <borovcanindusan1@gmail.com>
This commit is contained in:
Dušan Borovčanin
2026-03-06 11:09:40 +01:00
committed by GitHub
parent cef4b1d14d
commit abd669c610
16 changed files with 696 additions and 519 deletions
+9
View File
@@ -134,6 +134,15 @@ func Migration() *migrate.MemoryMigrationSource {
`ALTER TABLE pat_scopes RENAME COLUMN domain_id TO optional_domain_id;`,
},
},
{
Id: "auth_8",
Up: []string{
`CREATE INDEX IF NOT EXISTS idx_pats_user_id ON pats(user_id);`,
},
Down: []string{
`DROP INDEX IF EXISTS idx_pats_user_id;`,
},
},
},
}
}
+1
View File
@@ -23,6 +23,7 @@ type dbPat struct {
Revoked bool `db:"revoked,omitempty"`
RevokedAt sql.NullTime `db:"revoked_at,omitempty"`
Status auth.Status `db:"status,omitempty"`
TotalCount uint64 `db:"total_count"`
}
type dbScope struct {
+44 -42
View File
@@ -72,14 +72,15 @@ func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSP
}
q := fmt.Sprintf(`
SELECT
SELECT
p.id, p.user_id, p.name, p.description, p.issued_at, p.expires_at,
p.updated_at, p.revoked, p.revoked_at,
CASE
CASE
WHEN p.revoked = TRUE THEN %d
WHEN expires_at IS NOT NULL AND expires_at < :timestamp THEN %d
ELSE %d
END AS status
END AS status,
COUNT(*) OVER() AS total_count
FROM pats p WHERE user_id = :user_id %s
ORDER BY issued_at DESC
LIMIT :limit OFFSET :offset`, auth.RevokedStatus, auth.ExpiredStatus, auth.ActiveStatus, pageQuery)
@@ -100,30 +101,31 @@ func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSP
}
defer rows.Close()
var total uint64
items := []auth.PAT{}
for rows.Next() {
var pat dbPat
if err := rows.StructScan(&pat); err != nil {
return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
total = pat.TotalCount
items = append(items, toAuthPat(pat))
}
cq := fmt.Sprintf(`SELECT COUNT(*) FROM pats p WHERE user_id = :user_id %s`, pageQuery)
total, err := postgres.Total(ctx, pr.db, cq, dbPage)
if err != nil {
return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
if len(items) == 0 {
cq := fmt.Sprintf(`SELECT COUNT(*) FROM pats p WHERE user_id = :user_id %s`, pageQuery)
total, err = postgres.Total(ctx, pr.db, cq, dbPage)
if err != nil {
return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
}
page := auth.PATSPage{
return auth.PATSPage{
PATS: items,
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
}
return page, nil
}, nil
}
func PageQuery(pm auth.PATSPageMeta) (string, error) {
@@ -408,14 +410,15 @@ func (pr *patRepo) AddScope(ctx context.Context, userID string, scopes []auth.Sc
func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope, error) {
q := `
SELECT COUNT(*)
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND domain_id = :domain_id
AND operation = :operation
AND entity_id = :entity_id
LIMIT 1`
SELECT EXISTS (
SELECT 1
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND domain_id = :domain_id
AND operation = :operation
AND entity_id = :entity_id
)`
params := dbScope{
PatID: sc.PatID,
@@ -431,18 +434,28 @@ func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope,
}
defer rows.Close()
var count int
var exists bool
if rows.Next() {
if err := rows.Scan(&count); err != nil {
if err := rows.Scan(&exists); err != nil {
return auth.Scope{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
}
if count > 0 {
if exists {
return auth.Scope{}, repoerr.ErrConflict
}
if sc.EntityID == auth.AnyIDs {
checkEntityQuery := `
SELECT EXISTS (
SELECT 1
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND domain_id = :domain_id
AND operation = :operation
)`
newParams := dbScope{
PatID: sc.PatID,
DomainID: sc.DomainID,
@@ -450,42 +463,31 @@ func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope,
Operation: sc.Operation,
}
checkEntityQuery := `
SELECT COUNT(*)
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND domain_id = :domain_id
AND operation = :operation
LIMIT 1`
rows, err := pr.db.NamedQueryContext(ctx, checkEntityQuery, newParams)
if err != nil {
return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
defer rows.Close()
var count int
var scopeExists bool
if rows.Next() {
if err := rows.Scan(&count); err != nil {
if err := rows.Scan(&scopeExists); err != nil {
return auth.Scope{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
}
if count > 0 {
if scopeExists {
updateWithWildcardQuery := `
UPDATE pat_scopes
SET entity_id = :entity_id
WHERE pat_id = :pat_id
UPDATE pat_scopes
SET entity_id = :entity_id
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND domain_id = :domain_id
AND operation = :operation`
rows, err = pr.db.NamedQueryContext(ctx, updateWithWildcardQuery, params)
if err != nil {
if _, err := pr.db.NamedExecContext(ctx, updateWithWildcardQuery, params); err != nil {
return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
defer rows.Close()
return auth.Scope{}, nil
}
@@ -657,7 +659,7 @@ func (pr *patRepo) retrievePATFromDB(ctx context.Context, userID, patID string)
WHEN revoked = TRUE THEN %d
WHEN expires_at IS NOT NULL AND expires_at < :timestamp THEN %d
ELSE %d
END AS status
END AS status
FROM pats WHERE user_id = :user_id AND id = :id`, auth.RevokedStatus, auth.ExpiredStatus, auth.ActiveStatus)
dbp := dbPagemeta{
+102 -94
View File
@@ -183,7 +183,7 @@ func (cr *channelRepository) RetrieveByIDWithRoles(ctx context.Context, id, memb
SELECT
c.id,
c.parent_group_id,
COALESCE(g."path", ''::::ltree) AS parent_group_path,
COALESCE(g."path", CAST('' AS ltree)) AS parent_group_path,
c.domain_id
FROM
channels c
@@ -201,7 +201,7 @@ func (cr *channelRepository) RetrieveByIDWithRoles(ctx context.Context, id, memb
cr."name" AS role_name,
jsonb_agg(DISTINCT cra."action") AS actions,
'direct' AS access_type,
''::::ltree AS access_provider_path,
CAST('' AS ltree) AS access_provider_path,
'' AS access_provider_id
FROM
channels_roles cr
@@ -260,7 +260,7 @@ func (cr *channelRepository) RetrieveByIDWithRoles(ctx context.Context, id, memb
dr.id AS role_id,
dr."name" AS role_name,
jsonb_agg(DISTINCT all_actions."action") AS actions,
''::::ltree access_provider_path,
CAST('' AS ltree) access_provider_path,
'domain' AS access_type,
dr.entity_id AS access_provider_id
FROM
@@ -414,7 +414,7 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.Page)
COALESCE(c.domain_id, '') AS domain_id,
COALESCE(parent_group_id, '') AS parent_group_id,
c.route,
COALESCE((SELECT path FROM groups WHERE id = c.parent_group_id), ''::::ltree) AS parent_group_path,
COALESCE(g.path, CAST('' AS ltree)) AS parent_group_path,
c.status,
c.created_by,
c.created_at,
@@ -422,6 +422,8 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.Page)
COALESCE(c.updated_by, '') AS updated_by
FROM
channels c
LEFT JOIN
groups g ON g.id = c.parent_group_id
)
SELECT
c.*
@@ -492,7 +494,7 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
return channels.ChannelsPage{}, err
}
bq := repo.userChannelsBaseQuery(domainID, userID)
bq := userChannelsBaseQuery
connJoinQuery := `
FROM
@@ -517,6 +519,34 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
`
}
dbPage, err := toDBChannelsPage(pm)
if err != nil {
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
dbPage.UserID = userID
dbPage.DomainID = domainID
if pm.OnlyTotal {
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM final_channels c
%s;
`, bq, pageQuery)
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return channels.ChannelsPage{
Page: channels.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
}, nil
}
q := fmt.Sprintf(`
%s
SELECT
@@ -540,7 +570,8 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
c.access_provider_id,
c.access_provider_role_id,
c.access_provider_role_name,
c.access_provider_role_actions
c.access_provider_role_actions,
COUNT(*) OVER() AS total_count
%s
%s
`, bq, connJoinQuery, pageQuery)
@@ -549,83 +580,54 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
q = applyLimitOffset(q)
dbPage, err := toDBChannelsPage(pm)
if err != nil {
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
var items []channels.Channel
if !pm.OnlyTotal {
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
for rows.Next() {
dbc := dbChannel{}
if err := rows.StructScan(&dbc); err != nil {
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := toChannel(dbc)
if err != nil {
return channels.ChannelsPage{}, err
}
items = append(items, c)
}
}
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM (
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.route,
c.tags,
c.metadata,
c.created_by,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
c.parent_group_path,
c.role_id,
c.role_name,
c.actions,
c.access_type,
c.access_provider_id,
c.access_provider_role_id,
c.access_provider_role_name,
c.access_provider_role_actions
%s
%s
) AS subquery;
`, bq, connJoinQuery, pageQuery)
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
page := channels.ChannelsPage{
var total uint64
var items []channels.Channel
for rows.Next() {
dbc := dbChannel{}
if err := rows.StructScan(&dbc); err != nil {
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
total = dbc.TotalCount
c, err := toChannel(dbc)
if err != nil {
return channels.ChannelsPage{}, err
}
items = append(items, c)
}
if len(items) == 0 {
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM final_channels c
%s;
`, bq, pageQuery)
total, err = postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
}
return channels.ChannelsPage{
Channels: items,
Page: channels.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
}
return page, nil
}, nil
}
func (repo *channelRepository) userChannelsBaseQuery(domainID, userID string) string {
return fmt.Sprintf(`
const userChannelsBaseQuery = `
WITH direct_channels AS (
select
c.id,
@@ -640,7 +642,7 @@ WITH direct_channels AS (
c.updated_at,
c.updated_by,
c.status,
COALESCE((SELECT path FROM groups WHERE id = c.parent_group_id), ''::::ltree) AS parent_group_path,
COALESCE(pg.path, CAST('' AS ltree)) AS parent_group_path,
cr.id AS role_id,
cr."name" AS role_name,
array_agg(cra."action") AS actions,
@@ -648,7 +650,7 @@ WITH direct_channels AS (
'' AS access_provider_id,
'' AS access_provider_role_id,
'' AS access_provider_role_name,
array[]::::text[] AS access_provider_role_actions
CAST(array[] AS text[]) AS access_provider_role_actions
FROM
channels_role_members crm
JOIN
@@ -657,11 +659,13 @@ WITH direct_channels AS (
channels_roles cr ON cr.id = crm.role_id
JOIN
channels c ON c.id = cr.entity_id
LEFT JOIN
groups pg ON pg.id = c.parent_group_id
WHERE
crm.member_id = '%s'
AND c.domain_id = '%s'
crm.member_id = :user_id
AND c.domain_id = :domain_id_param
GROUP BY
cr.entity_id, crm.member_id, cr.id, cr."name", c.id
cr.entity_id, crm.member_id, cr.id, cr."name", c.id, pg.path
),
direct_groups AS (
SELECT
@@ -682,9 +686,9 @@ direct_groups AS (
JOIN
groups_role_actions all_actions ON all_actions.role_id = grm.role_id
WHERE
grm.member_id = '%s'
AND g.domain_id = '%s'
AND gra."action" LIKE 'channel%%'
grm.member_id = :user_id
AND g.domain_id = :domain_id_param
AND gra."action" LIKE 'channel%'
GROUP BY
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
),
@@ -707,9 +711,9 @@ direct_groups_with_subgroup AS (
JOIN
groups_role_actions all_actions ON all_actions.role_id = grm.role_id
WHERE
grm.member_id = '%s'
AND g.domain_id = '%s'
AND gra."action" LIKE 'subgroup_channel%%'
grm.member_id = :user_id
AND g.domain_id = :domain_id_param
AND gra."action" LIKE 'subgroup_channel%'
GROUP BY
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
),
@@ -737,7 +741,7 @@ indirect_child_groups AS (
JOIN
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path
WHERE
indirect_child_groups.domain_id = '%s'
indirect_child_groups.domain_id = :domain_id_param
AND NOT EXISTS (
SELECT 1
FROM direct_groups_with_subgroup dgws
@@ -759,7 +763,7 @@ final_groups AS (
"path",
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
CAST(array[] AS text[]) AS actions,
'direct_group' AS access_type,
id AS access_provider_id,
role_id AS access_provider_role_id,
@@ -782,7 +786,7 @@ final_groups AS (
"path",
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
CAST(array[] AS text[]) AS actions,
'indirect_group' AS access_type,
access_provider_id,
access_provider_role_id,
@@ -819,7 +823,7 @@ groups_channels AS (
JOIN
channels c ON c.parent_group_id = g.id
WHERE
c.id NOT IN (SELECT id FROM direct_channels)
NOT EXISTS (SELECT 1 FROM direct_channels dc WHERE dc.id = c.id)
UNION
SELECT * FROM direct_channels
),
@@ -865,7 +869,7 @@ final_channels AS (
g."path" AS parent_group_path,
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
CAST(array[] AS text[]) AS actions,
'domain' AS access_type,
d.id AS access_provider_id,
dr.id AS access_provider_role_id,
@@ -884,18 +888,17 @@ final_channels AS (
LEFT JOIN
groups g ON dc.parent_group_id = g.id
WHERE
drm.member_id = '%s' -- user_id
AND d.id = '%s' -- domain_id
AND dra."action" LIKE 'channel_%%'
AND NOT EXISTS ( -- Ensures that the direct and indirect channels are not in included.
drm.member_id = :user_id
AND d.id = :domain_id_param
AND dra."action" LIKE 'channel_%'
AND NOT EXISTS (
SELECT 1 FROM groups_channels gc
WHERE gc.id = dc.id
)
GROUP BY
dc.id, d.id, dr.id, g."path"
)
`, userID, domainID, userID, domainID, userID, domainID, domainID, userID, domainID)
}
`
func (cr *channelRepository) Remove(ctx context.Context, ids ...string) error {
q := "DELETE FROM channels AS c WHERE c.id = ANY(:channel_ids) ;"
@@ -1145,6 +1148,7 @@ type dbChannel struct {
ConnectionTypes pq.Int32Array `db:"connection_types,omitempty"`
MemberID string `db:"member_id,omitempty"`
Roles json.RawMessage `db:"roles,omitempty"`
TotalCount uint64 `db:"total_count"`
}
func toDBChannel(ch channels.Channel) (dbChannel, error) {
@@ -1303,7 +1307,7 @@ func PageQuery(pm channels.Page) (string, error) {
}
if len(pm.IDs) != 0 {
query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','")))
query = append(query, "id = ANY(:ids)")
}
if pm.Status != channels.AllStatus {
query = append(query, "c.status = :status")
@@ -1415,6 +1419,7 @@ func toDBChannelsPage(pm channels.Page) (dbChannelsPage, error) {
RoleID: pm.RoleID,
Actions: pm.Actions,
AccessType: pm.AccessType,
IDs: pq.StringArray(pm.IDs),
CreatedFrom: pm.CreatedFrom,
CreatedTo: pm.CreatedTo,
}, nil
@@ -1438,6 +1443,9 @@ type dbChannelsPage struct {
AccessType string `db:"access_type"`
CreatedFrom time.Time `db:"created_from"`
CreatedTo time.Time `db:"created_to"`
IDs pq.StringArray `db:"ids"`
UserID string `db:"user_id"`
DomainID string `db:"domain_id_param"`
}
type dbConnection struct {
+11
View File
@@ -95,6 +95,17 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
`SELECT 1`,
},
},
{
Id: "channels_06",
Up: []string{
`CREATE INDEX IF NOT EXISTS idx_channels_domain_id_status ON channels(domain_id, status);`,
`CREATE INDEX IF NOT EXISTS idx_channels_parent_group_id ON channels(parent_group_id);`,
},
Down: []string{
`DROP INDEX IF EXISTS idx_channels_domain_id_status;`,
`DROP INDEX IF EXISTS idx_channels_parent_group_id;`,
},
},
},
}
channelsMigration.Migrations = append(channelsMigration.Migrations, rolesMigration.Migrations...)
+105 -98
View File
@@ -195,7 +195,7 @@ func (repo *clientRepo) RetrieveByIDWithRoles(ctx context.Context, id, memberID
SELECT
c.id,
c.parent_group_id,
COALESCE(g."path", ''::::ltree) AS parent_group_path,
COALESCE(g."path", CAST('' AS ltree)) AS parent_group_path,
c.domain_id
FROM
clients c
@@ -213,7 +213,7 @@ func (repo *clientRepo) RetrieveByIDWithRoles(ctx context.Context, id, memberID
cr."name" AS role_name,
jsonb_agg(DISTINCT cra."action") AS actions,
'direct' AS access_type,
''::::ltree AS access_provider_path,
CAST('' AS ltree) AS access_provider_path,
'' AS access_provider_id
FROM
clients_roles cr
@@ -272,7 +272,7 @@ func (repo *clientRepo) RetrieveByIDWithRoles(ctx context.Context, id, memberID
dr.id AS role_id,
dr."name" AS role_name,
jsonb_agg(DISTINCT all_actions."action") AS actions,
''::::ltree access_provider_path,
CAST('' AS ltree) access_provider_path,
'domain' AS access_type,
dr.entity_id AS access_provider_id
FROM
@@ -452,13 +452,15 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien
c.metadata,
COALESCE(c.domain_id, '') AS domain_id,
COALESCE(parent_group_id, '') AS parent_group_id,
COALESCE((SELECT path FROM groups WHERE id = c.parent_group_id), ''::::ltree) AS parent_group_path,
COALESCE(g.path, CAST('' AS ltree)) AS parent_group_path,
c.status,
c.created_at,
c.updated_at,
COALESCE(c.updated_by, '') AS updated_by
FROM
clients c
LEFT JOIN
groups g ON g.id = c.parent_group_id
)
SELECT
c.*
@@ -529,7 +531,7 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
return clients.ClientsPage{}, err
}
bq := repo.userClientBaseQuery(domainID, userID)
bq := userClientBaseQuery
connJoinQuery := `
FROM
@@ -554,6 +556,34 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
`
}
dbPage, err := ToDBClientsPage(pm)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
dbPage.UserID = userID
dbPage.DomainID = domainID
if pm.OnlyTotal {
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM final_clients c
%s;
`, bq, pageQuery)
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return clients.ClientsPage{
Page: clients.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
}, nil
}
q := fmt.Sprintf(`
%s
SELECT
@@ -577,7 +607,8 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
c.access_provider_id,
c.access_provider_role_id,
c.access_provider_role_name,
c.access_provider_role_actions
c.access_provider_role_actions,
COUNT(*) OVER() AS total_count
%s
%s
`, bq, connJoinQuery, pageQuery)
@@ -586,83 +617,54 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
q = applyLimitOffset(q)
dbPage, err := ToDBClientsPage(pm)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
var items []clients.Client
if !pm.OnlyTotal {
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := ToClient(dbc)
if err != nil {
return clients.ClientsPage{}, err
}
items = append(items, c)
}
}
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM (
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.identity,
c.secret,
c.tags,
c.metadata,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
c.parent_group_path,
c.role_id,
c.role_name,
c.actions,
c.access_type,
c.access_provider_id,
c.access_provider_role_id,
c.access_provider_role_name,
c.access_provider_role_actions
%s
%s
) AS subquery;
`, bq, connJoinQuery, pageQuery)
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
page := clients.ClientsPage{
var total uint64
var items []clients.Client
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
total = dbc.TotalCount
c, err := ToClient(dbc)
if err != nil {
return clients.ClientsPage{}, err
}
items = append(items, c)
}
if len(items) == 0 {
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM final_clients c
%s;
`, bq, pageQuery)
total, err = postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
}
return clients.ClientsPage{
Clients: items,
Page: clients.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
}
return page, nil
}, nil
}
func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
return fmt.Sprintf(`
const userClientBaseQuery = `
WITH direct_clients AS (
SELECT
c.id,
@@ -677,7 +679,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
c.updated_at,
c.updated_by,
c.status,
COALESCE((SELECT path FROM groups WHERE id = c.parent_group_id), ''::::ltree) AS parent_group_path,
COALESCE(pg.path, CAST('' AS ltree)) AS parent_group_path,
cr.id AS role_id,
cr."name" AS role_name,
array_agg(cra."action") AS actions,
@@ -685,7 +687,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
'' AS access_provider_id,
'' AS access_provider_role_id,
'' AS access_provider_role_name,
array[]::::text[] AS access_provider_role_actions
CAST(array[] AS text[]) AS access_provider_role_actions
FROM
clients_role_members crm
JOIN
@@ -694,11 +696,13 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
clients_roles cr ON cr.id = crm.role_id
JOIN
clients c ON c.id = cr.entity_id
LEFT JOIN
groups pg ON pg.id = c.parent_group_id
WHERE
crm.member_id = '%s'
AND c.domain_id = '%s'
crm.member_id = :user_id
AND c.domain_id = :domain_id_param
GROUP BY
cr.entity_id, crm.member_id, cr.id, cr."name", c.id
cr.entity_id, crm.member_id, cr.id, cr."name", c.id, pg.path
),
direct_groups AS (
SELECT
@@ -719,9 +723,9 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
JOIN
groups_role_actions all_actions ON all_actions.role_id = grm.role_id
WHERE
grm.member_id = '%s'
AND g.domain_id = '%s'
AND gra."action" LIKE 'client%%'
grm.member_id = :user_id
AND g.domain_id = :domain_id_param
AND gra."action" LIKE 'client%'
GROUP BY
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
),
@@ -744,9 +748,9 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
JOIN
groups_role_actions all_actions ON all_actions.role_id = grm.role_id
WHERE
grm.member_id = '%s'
AND g.domain_id = '%s'
AND gra."action" LIKE 'subgroup_client%%'
grm.member_id = :user_id
AND g.domain_id = :domain_id_param
AND gra."action" LIKE 'subgroup_client%'
GROUP BY
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
),
@@ -772,10 +776,10 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
FROM
direct_leaf_groups_with_subgroup dlgws
JOIN
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path -- Finds all children of entity_id based on ltree path
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path
WHERE
indirect_child_groups.domain_id = '%s'
AND NOT EXISTS (
indirect_child_groups.domain_id = :domain_id_param
AND NOT EXISTS (
SELECT 1
FROM direct_groups_with_subgroup dgws
WHERE dgws.id = indirect_child_groups.id
@@ -796,7 +800,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
"path",
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
CAST(array[] AS text[]) AS actions,
'direct_group' AS access_type,
id AS access_provider_id,
role_id AS access_provider_role_id,
@@ -819,7 +823,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
"path",
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
CAST(array[] AS text[]) AS actions,
'indirect_group' AS access_type,
access_provider_id,
access_provider_role_id,
@@ -856,7 +860,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
JOIN
clients c ON c.parent_group_id = g.id
WHERE
c.id NOT IN (SELECT id FROM direct_clients)
NOT EXISTS (SELECT 1 FROM direct_clients dc WHERE dc.id = c.id)
UNION
SELECT * FROM direct_clients
),
@@ -902,7 +906,7 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
g."path" AS parent_group_path,
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
CAST(array[] AS text[]) AS actions,
'domain' AS access_type,
d.id AS access_provider_id,
dr.id AS access_provider_role_id,
@@ -921,18 +925,17 @@ func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
LEFT JOIN
groups g ON dc.parent_group_id = g.id
WHERE
drm.member_id = '%s' -- user_id
AND d.id = '%s' -- domain_id
AND dra."action" LIKE 'client_%%'
AND NOT EXISTS ( -- Ensures that the direct and indirect clients are not in included.
drm.member_id = :user_id
AND d.id = :domain_id_param
AND dra."action" LIKE 'client_%'
AND NOT EXISTS (
SELECT 1 FROM groups_clients gc
WHERE gc.id = dc.id
)
GROUP BY
dc.id, d.id, dr.id, g."path"
)
`, userID, domainID, userID, domainID, userID, domainID, domainID, userID, domainID)
}
`
func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
query, err := PageQuery(pm)
@@ -1056,6 +1059,7 @@ type DBClient struct {
ConnectionTypes pq.Int32Array `db:"connection_types,omitempty"`
MemberID string `db:"member_id,omitempty"`
Roles json.RawMessage `db:"roles,omitempty"`
TotalCount uint64 `db:"total_count"`
}
func ToDBClient(c clients.Client) (DBClient, error) {
@@ -1206,6 +1210,7 @@ func ToDBClientsPage(pm clients.Page) (dbClientsPage, error) {
RoleID: pm.RoleID,
Actions: pm.Actions,
AccessType: pm.AccessType,
IDs: pq.StringArray(pm.IDs),
CreatedFrom: pm.CreatedFrom,
CreatedTo: pm.CreatedTo,
}, nil
@@ -1230,6 +1235,9 @@ type dbClientsPage struct {
AccessType string `db:"access_type"`
CreatedFrom time.Time `db:"created_from"`
CreatedTo time.Time `db:"created_to"`
IDs pq.StringArray `db:"ids"`
UserID string `db:"user_id"`
DomainID string `db:"domain_id_param"`
}
func PageQuery(pm clients.Page) (string, error) {
@@ -1252,7 +1260,7 @@ func PageQuery(pm clients.Page) (string, error) {
}
}
if len(pm.IDs) != 0 {
query = append(query, fmt.Sprintf("c.id IN ('%s')", strings.Join(pm.IDs, "','")))
query = append(query, "c.id = ANY(:ids)")
}
if pm.Status != clients.AllStatus {
@@ -1432,9 +1440,8 @@ func (repo *clientRepo) RemoveConnections(ctx context.Context, conns []clients.C
}
}()
query := `DELETE FROM connections WHERE channel_id = :channel_id AND domain_id = :domain_id AND client_id = :client_id`
for _, conn := range conns {
query := `DELETE FROM connections WHERE channel_id = :channel_id AND domain_id = :domain_id AND client_id = :client_id`
if uint8(conn.Type) > 0 {
query = query + " AND type = :type "
}
+13
View File
@@ -99,6 +99,19 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
`SELECT 1`,
},
},
{
Id: "clients_06",
Up: []string{
`CREATE INDEX IF NOT EXISTS idx_clients_domain_id_status ON clients(domain_id, status);`,
`CREATE INDEX IF NOT EXISTS idx_clients_parent_group_id ON clients(parent_group_id);`,
`CREATE INDEX IF NOT EXISTS idx_connections_client_id ON connections(client_id);`,
},
Down: []string{
`DROP INDEX IF EXISTS idx_clients_domain_id_status;`,
`DROP INDEX IF EXISTS idx_clients_parent_group_id;`,
`DROP INDEX IF EXISTS idx_connections_client_id;`,
},
},
},
}
+90 -76
View File
@@ -20,7 +20,6 @@ import (
"github.com/absmach/supermq/pkg/roles"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
"github.com/jackc/pgtype"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)
@@ -252,7 +251,6 @@ func (repo domainRepo) RetrieveDomainByRoute(ctx context.Context, route string)
// RetrieveAllByIDs retrieves for given Domain IDs .
func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) {
var q string
if len(pm.IDs) == 0 {
return domains.DomainsPage{}, nil
}
@@ -263,11 +261,11 @@ func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.P
baseQ := `SELECT d.id as id, d.name as name, d.tags as tags, d.route as route, d.metadata as metadata,
d.created_at as created_at, d.updated_at as updated_at, d.updated_by as updated_by,
d.created_by as created_by, d.status as status FROM domains d`
d.created_by as created_by, d.status as status, COUNT(*) OVER() AS total_count FROM domains d`
squery := applyOrdering(query, pm)
q = fmt.Sprintf("%s %s LIMIT %d OFFSET %d;", baseQ, squery, pm.Limit, pm.Offset)
q := fmt.Sprintf("%s %s LIMIT %d OFFSET %d;", baseQ, squery, pm.Limit, pm.Offset)
dbPage, err := toDBDomainsPage(pm)
if err != nil {
@@ -280,19 +278,30 @@ func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.P
}
defer rows.Close()
doms, err := repo.processRows(rows)
if err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
var total uint64
var doms []domains.Domain
for rows.Next() {
dbd := dbDomain{}
if err := rows.StructScan(&dbd); err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
total = dbd.TotalCount
d, err := toDomain(dbd)
if err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
doms = append(doms, d)
}
cq := "SELECT COUNT(*) FROM domains d"
if query != "" {
cq = fmt.Sprintf(" %s %s", cq, query)
}
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
if len(doms) == 0 {
cq := "SELECT COUNT(*) FROM domains d"
if query != "" {
cq = fmt.Sprintf(" %s %s", cq, query)
}
total, err = postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
}
return domains.DomainsPage{
@@ -321,14 +330,15 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
d.updated_at as updated_at,
d.updated_by as updated_by,
d.created_by as created_by,
d.status as status
d.status as status,
COUNT(*) OVER() AS total_count
FROM
domains as d
%s
LIMIT :limit OFFSET :offset`
if pm.UserID != "" {
q = repo.userDomainsBaseQuery() +
q = userDomainsBaseQuery +
`
SELECT
d.id as id,
@@ -343,7 +353,8 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
d.created_at as created_at,
d.updated_at as updated_at,
d.updated_by as updated_by,
d.created_by as created_by
d.created_by as created_by,
COUNT(*) OVER() AS total_count
FROM
domains d
%s
@@ -358,35 +369,55 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
var doms []domains.Domain
if !pm.OnlyTotal {
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
doms, err = repo.processRows(rows)
if pm.OnlyTotal {
cq := `SELECT COUNT(*) FROM domains as d %s`
if pm.UserID != "" {
cq = userDomainsBaseQuery + cq
}
if query != "" {
cq = fmt.Sprintf(cq, query)
}
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
return domains.DomainsPage{Total: total, Offset: pm.Offset, Limit: pm.Limit}, nil
}
cq := `SELECT COUNT(*)
FROM domains as d %s`
if pm.UserID != "" {
cq = repo.userDomainsBaseQuery() + cq
}
if query != "" {
cq = fmt.Sprintf(cq, query)
}
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
var total uint64
var doms []domains.Domain
for rows.Next() {
dbd := dbDomain{}
if err := rows.StructScan(&dbd); err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
total = dbd.TotalCount
d, err := toDomain(dbd)
if err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
doms = append(doms, d)
}
if len(doms) == 0 {
cq := `SELECT COUNT(*) FROM domains as d %s`
if pm.UserID != "" {
cq = userDomainsBaseQuery + cq
}
if query != "" {
cq = fmt.Sprintf(cq, query)
}
total, err = postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
}
return domains.DomainsPage{
Total: total,
@@ -479,8 +510,7 @@ func (repo domainRepo) DeleteDomain(ctx context.Context, id string) error {
return nil
}
func (repo domainRepo) userDomainsBaseQuery() string {
return `
const userDomainsBaseQuery = `
with domains AS (
SELECT
d.id as id,
@@ -511,23 +541,6 @@ func (repo domainRepo) userDomainsBaseQuery() string {
GROUP BY
dr.entity_id, drm.member_id, dr.id, dr."name", d.id
)`
}
func (repo domainRepo) processRows(rows *sqlx.Rows) ([]domains.Domain, error) {
var items []domains.Domain
for rows.Next() {
dbd := dbDomain{}
if err := rows.StructScan(&dbd); err != nil {
return items, err
}
d, err := toDomain(dbd)
if err != nil {
return items, err
}
items = append(items, d)
}
return items, nil
}
func applyOrdering(emq string, pm domains.Page) string {
col := "COALESCE(d.updated_at, d.created_at)"
@@ -550,21 +563,22 @@ func applyOrdering(emq string, pm domains.Page) string {
}
type dbDomain struct {
ID string `db:"id"`
Name string `db:"name"`
Metadata []byte `db:"metadata,omitempty"`
Tags pgtype.TextArray `db:"tags,omitempty"`
Route *string `db:"route,omitempty"`
Status domains.Status `db:"status"`
RoleID string `db:"role_id"`
RoleName string `db:"role_name"`
Actions pq.StringArray `db:"actions"`
CreatedBy string `db:"created_by"`
CreatedAt time.Time `db:"created_at"`
UpdatedBy *string `db:"updated_by,omitempty"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
MemberID string `db:"member_id,omitempty"`
Roles json.RawMessage `db:"roles,omitempty"`
ID string `db:"id"`
Name string `db:"name"`
Metadata []byte `db:"metadata,omitempty"`
Tags pgtype.TextArray `db:"tags,omitempty"`
Route *string `db:"route,omitempty"`
Status domains.Status `db:"status"`
RoleID string `db:"role_id"`
RoleName string `db:"role_name"`
Actions pq.StringArray `db:"actions"`
CreatedBy string `db:"created_by"`
CreatedAt time.Time `db:"created_at"`
UpdatedBy *string `db:"updated_by,omitempty"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
MemberID string `db:"member_id,omitempty"`
Roles json.RawMessage `db:"roles,omitempty"`
TotalCount uint64 `db:"total_count"`
}
func toDBDomain(d domains.Domain) (dbDomain, error) {
@@ -670,7 +684,7 @@ type dbDomainsPage struct {
RoleName string `db:"role_name"`
Actions pq.StringArray `db:"actions"`
ID string `db:"id"`
IDs []string `db:"ids"`
IDs pq.StringArray `db:"ids"`
Metadata []byte `db:"metadata"`
Tags pgtype.TextArray `db:"tags"`
Status domains.Status `db:"status"`
@@ -699,7 +713,7 @@ func toDBDomainsPage(pm domains.Page) (dbDomainsPage, error) {
RoleName: pm.RoleName,
Actions: pm.Actions,
ID: pm.ID,
IDs: pm.IDs,
IDs: pq.StringArray(pm.IDs),
Metadata: data,
Tags: tags,
Status: pm.Status,
@@ -718,7 +732,7 @@ func buildPageQuery(pm domains.Page) (string, error) {
}
if len(pm.IDs) != 0 {
query = append(query, fmt.Sprintf("d.id IN ('%s')", strings.Join(pm.IDs, "','")))
query = append(query, "d.id = ANY(:ids)")
}
if (pm.Status >= domains.EnabledStatus) && (pm.Status < domains.AllStatus) {
+11
View File
@@ -160,6 +160,17 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
END $$;`,
},
},
{
Id: "domain_6",
Up: []string{
`CREATE INDEX IF NOT EXISTS idx_invitations_invited_by ON invitations(invited_by);`,
`CREATE INDEX IF NOT EXISTS idx_invitations_role_id ON invitations(role_id);`,
},
Down: []string{
`DROP INDEX IF EXISTS idx_invitations_invited_by;`,
`DROP INDEX IF EXISTS idx_invitations_role_id;`,
},
},
{
Id: "domain_7",
Up: []string{
+188 -125
View File
@@ -60,7 +60,7 @@ func New(db postgres.Database) groups.Repository {
}
func (repo groupRepository) Save(ctx context.Context, g groups.Group) (groups.Group, error) {
q, err := repo.getInsertQuery(ctx, g)
q, computedPath, err := repo.getInsertQuery(ctx, g)
if err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrCreateEntity, err)
}
@@ -68,6 +68,9 @@ func (repo groupRepository) Save(ctx context.Context, g groups.Group) (groups.Gr
if err != nil {
return groups.Group{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
}
if computedPath != "" {
dbg.Path = computedPath
}
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
if err != nil {
@@ -256,7 +259,7 @@ func (repo groupRepository) RetrieveByIDWithRoles(ctx context.Context, id, membe
dr.id AS role_id,
dr.name AS role_name,
jsonb_agg(DISTINCT all_actions.action) AS actions,
''::::ltree access_provider_path,
CAST('' AS ltree) access_provider_path,
'domain' AS access_type,
dr.entity_id AS access_provider_id
FROM
@@ -362,9 +365,9 @@ func (repo groupRepository) RetrieveByIDWithRoles(ctx context.Context, id, membe
}
func (repo groupRepository) RetrieveByIDAndUser(ctx context.Context, domainID, userID, groupID string) (groups.Group, error) {
baseQuery := repo.userGroupsBaseQuery(domainID, userID)
baseQuery := userGroupsBaseQuery
dbg := dbGroup{ID: groupID}
dbg := dbGroup{ID: groupID, UserID: userID, DomainIDParam: domainID}
q := fmt.Sprintf(`%s
SELECT
g.id,
@@ -438,39 +441,54 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
orderClause = fmt.Sprintf("ORDER BY %s %s, g.id %s", orderBy, dir, dir)
}
q := fmt.Sprintf(`SELECT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s %s LIMIT :limit OFFSET :offset;`, query, orderClause)
dbPageMeta, err := toDBGroupPageMeta(pm)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
var items []groups.Group
if !pm.OnlyTotal {
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
items, err = repo.processRows(rows)
if pm.OnlyTotal {
cq := fmt.Sprintf(`SELECT COUNT(*) FROM groups g %s;`, query)
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
page := groups.Page{PageMeta: pm}
page.Total = total
return page, nil
}
cq := fmt.Sprintf(` SELECT COUNT(*) AS total_count
FROM (
SELECT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s
) AS subquery;
`, query)
q := fmt.Sprintf(`SELECT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status,
COUNT(*) OVER() AS total_count FROM groups g %s %s LIMIT :limit OFFSET :offset;`, query, orderClause)
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
var total uint64
var items []groups.Group
for rows.Next() {
dbg := dbGroup{}
if err := rows.StructScan(&dbg); err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
total = dbg.TotalCount
g, err := toGroup(dbg)
if err != nil {
return groups.Page{}, err
}
items = append(items, g)
}
if len(items) == 0 {
cq := fmt.Sprintf(`SELECT COUNT(*) FROM groups g %s;`, query)
total, err = postgres.Total(ctx, repo.db, cq, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
}
page := groups.Page{PageMeta: pm}
page.Total = total
@@ -479,40 +497,49 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
}
func (repo groupRepository) RetrieveByIDs(ctx context.Context, pm groups.PageMeta, ids ...string) (groups.Page, error) {
var q string
if (len(ids) == 0) && (pm.DomainID == "") {
return groups.Page{PageMeta: groups.PageMeta{Offset: pm.Offset, Limit: pm.Limit}}, nil
}
query := buildQuery(pm, ids...)
q = fmt.Sprintf(`SELECT DISTINCT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s ORDER BY g.created_at LIMIT :limit OFFSET :offset;`, query)
q := fmt.Sprintf(`SELECT DISTINCT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status,
COUNT(*) OVER() AS total_count FROM groups g %s ORDER BY g.created_at LIMIT :limit OFFSET :offset;`, query)
dbPageMeta, err := toDBGroupPageMeta(pm)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
dbPageMeta.IDs = pq.StringArray(ids)
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
items, err := repo.processRows(rows)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
var total uint64
var items []groups.Group
for rows.Next() {
dbg := dbGroup{}
if err := rows.StructScan(&dbg); err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
total = dbg.TotalCount
g, err := toGroup(dbg)
if err != nil {
return groups.Page{}, err
}
items = append(items, g)
}
cq := fmt.Sprintf(` SELECT COUNT(*) AS total_count
FROM (
SELECT DISTINCT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s
) AS subquery;
`, query)
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
if len(items) == 0 {
cq := fmt.Sprintf(`SELECT COUNT(*) FROM (
SELECT DISTINCT g.id FROM groups g %s
) AS subquery;`, query)
total, err = postgres.Total(ctx, repo.db, cq, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
}
page := groups.Page{PageMeta: pm}
@@ -530,7 +557,7 @@ func (repo groupRepository) RetrieveHierarchy(ctx context.Context, domainID, use
dirQuery = "g.path <@ (SELECT path FROM groups WHERE id = :id)"
}
baseQuery := repo.userGroupsBaseQuery(domainID, userID)
baseQuery := userGroupsBaseQuery
query := fmt.Sprintf(`%s,
target_hierarchy AS (
SELECT
@@ -568,8 +595,10 @@ func (repo groupRepository) RetrieveHierarchy(ctx context.Context, domainID, use
`, baseQuery, dirQuery)
parameters := map[string]any{
"id": groupID,
"level": hm.Level,
"id": groupID,
"level": hm.Level,
"user_id": userID,
"domain_id_param": domainID,
}
rows, err := repo.db.NamedQueryContext(ctx, query, parameters)
@@ -798,7 +827,7 @@ func (repo groupRepository) RetrieveAllParentGroups(ctx context.Context, domainI
query := buildQuery(pm)
levelCondition := fmt.Sprintf("g.path @> '%s' ", cGroup.Path)
levelCondition := "g.path @> CAST(:path AS ltree) "
switch {
case query == "":
@@ -807,6 +836,7 @@ func (repo groupRepository) RetrieveAllParentGroups(ctx context.Context, domainI
query = query + " AND " + levelCondition
}
pm.Path = cGroup.Path
return repo.retrieveGroups(ctx, domainID, userID, query, pm)
}
@@ -822,19 +852,19 @@ func (repo groupRepository) RetrieveChildrenGroups(ctx context.Context, domainID
switch {
// Retrieve all children groups from parent group level
case startLevel == 0 && endLevel < 0:
levelCondition = fmt.Sprintf(" path ~ '%s.*'::::lquery ", pGroup.Path)
levelCondition = " path ~ CAST(:path || '.*' AS lquery) "
// Retrieve specific level of children groups from parent group level
case (startLevel > 0) && (startLevel == endLevel || endLevel == 0):
levelCondition = fmt.Sprintf(" path ~ '%s.*{%d}'::::lquery ", pGroup.Path, startLevel)
levelCondition = fmt.Sprintf(" path ~ CAST(:path || '.*{%d}' AS lquery) ", startLevel)
// Retrieve all children groups from specific level from parent group level
case startLevel > 0 && endLevel < 0:
levelCondition = fmt.Sprintf(" path ~ '%s.*{%d,}'::::lquery ", pGroup.Path, startLevel)
levelCondition = fmt.Sprintf(" path ~ CAST(:path || '.*{%d,}' AS lquery) ", startLevel)
// Retrieve children groups between specific level from parent group level
case startLevel > 0 && endLevel > 0 && startLevel < endLevel:
levelCondition = fmt.Sprintf(" path ~ '%s.*{%d,%d}'::::lquery ", pGroup.Path, startLevel, endLevel)
levelCondition = fmt.Sprintf(" path ~ CAST(:path || '.*{%d,%d}' AS lquery) ", startLevel, endLevel)
default:
return groups.Page{}, errors.Wrap(repoerr.ErrViewEntity, fmt.Errorf("invalid level range: start level: %d end level: %d", startLevel, endLevel))
}
@@ -846,6 +876,7 @@ func (repo groupRepository) RetrieveChildrenGroups(ctx context.Context, domainID
query = query + " AND " + levelCondition
}
pm.Path = pGroup.Path
return repo.retrieveGroups(ctx, domainID, userID, query, pm)
}
@@ -868,7 +899,7 @@ func (repo groupRepository) RetrieveUserGroups(ctx context.Context, domainID, us
}
func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID, query string, pm groups.PageMeta) (groups.Page, error) {
baseQuery := repo.userGroupsBaseQuery(domainID, userID)
baseQuery := userGroupsBaseQuery
orderClause := ""
var orderBy string
@@ -889,6 +920,30 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
orderClause = fmt.Sprintf("ORDER BY %s %s, g.id %s", orderBy, dir, dir)
}
dbPageMeta, err := toDBGroupPageMeta(pm)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
dbPageMeta.UserID = userID
dbPageMeta.DomainIDParam = domainID
if pm.OnlyTotal {
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM final_groups g
%s;
`, baseQuery, query)
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
page := groups.Page{PageMeta: pm}
page.Total = total
return page, nil
}
q := fmt.Sprintf(`%s
SELECT
g.id,
@@ -910,45 +965,49 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
g.access_provider_id,
g.access_provider_role_id,
g.access_provider_role_name,
g.access_provider_role_actions
g.access_provider_role_actions,
COUNT(*) OVER() AS total_count
FROM final_groups g
%s
%s
LIMIT :limit OFFSET :offset;`,
baseQuery, query, orderClause)
dbPageMeta, err := toDBGroupPageMeta(pm)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
var items []groups.Group
if !pm.OnlyTotal {
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
items, err = repo.processRows(rows)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
}
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM (
SELECT g.id
FROM final_groups g
%s
) AS subquery;`,
baseQuery, query)
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
var total uint64
var items []groups.Group
for rows.Next() {
dbg := dbGroup{}
if err := rows.StructScan(&dbg); err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
total = dbg.TotalCount
group, err := toGroup(dbg)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
items = append(items, group)
}
if len(items) == 0 {
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM final_groups g
%s;
`, baseQuery, query)
total, err = postgres.Total(ctx, repo.db, cq, dbPageMeta)
if err != nil {
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
}
page := groups.Page{PageMeta: pm}
page.Total = total
@@ -956,8 +1015,7 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
return page, nil
}
func (repo groupRepository) userGroupsBaseQuery(domainID, userID string) string {
return fmt.Sprintf(`
const userGroupsBaseQuery = `
WITH direct_groups AS (
SELECT
g.*,
@@ -975,8 +1033,8 @@ JOIN
JOIN
"groups" g ON g.id = gr.entity_id
WHERE
grm.member_id = '%s'
AND g.domain_id = '%s'
grm.member_id = :user_id
AND g.domain_id = :domain_id_param
GROUP BY
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
),
@@ -997,12 +1055,12 @@ direct_groups_with_subgroup AS (
JOIN
"groups" g ON g.id = gr.entity_id
WHERE
grm.member_id = '%s'
AND g.domain_id = '%s'
grm.member_id = :user_id
AND g.domain_id = :domain_id_param
GROUP BY
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
HAVING
bool_or(gra."action" LIKE 'subgroup_%%')
bool_or(gra."action" LIKE 'subgroup_%')
),
direct_leaf_groups_with_subgroup AS (
SELECT dgws.*
@@ -1026,11 +1084,10 @@ indirect_child_groups AS (
FROM
direct_leaf_groups_with_subgroup dlgws
JOIN
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path -- Finds all children of entity_id based on ltree path
groups indirect_child_groups ON indirect_child_groups.path <@ dlgws.path
WHERE
indirect_child_groups.domain_id = '%s'
AND
NOT EXISTS ( -- Ensures that the indirect_child_groups.id is not already in the direct_groups_with_subgroup table
indirect_child_groups.domain_id = :domain_id_param
AND NOT EXISTS (
SELECT 1
FROM direct_groups_with_subgroup dgws
WHERE dgws.id = indirect_child_groups.id
@@ -1057,7 +1114,7 @@ direct_indirect_groups as (
'' AS access_provider_id,
'' AS access_provider_role_id,
'' AS access_provider_role_name,
array[]::::text[] AS access_provider_role_actions
CAST(array[] AS text[]) AS access_provider_role_actions
FROM
direct_groups
UNION
@@ -1076,7 +1133,7 @@ direct_indirect_groups as (
"path",
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
CAST(array[] AS text[]) AS actions,
'indirect' AS access_type,
access_provider_id,
access_provider_role_id,
@@ -1125,7 +1182,7 @@ final_groups AS (
dg."path",
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
CAST(array[] AS text[]) AS actions,
'domain' AS access_type,
d.id AS access_provider_id,
dr.id AS access_provider_role_id,
@@ -1142,24 +1199,23 @@ final_groups AS (
JOIN
"groups" dg ON dg.domain_id = d.id
WHERE
drm.member_id = '%s' -- user_id
AND d.id = '%s' -- domain_id
AND dra."action" LIKE 'group_%%'
AND NOT EXISTS ( -- Ensures that the direct and indirect groups are not in included.
drm.member_id = :user_id
AND d.id = :domain_id_param
AND dra."action" LIKE 'group_%'
AND NOT EXISTS (
SELECT 1 FROM direct_indirect_groups dig
WHERE dig.id = dg.id
)
GROUP BY
dg.id, d.id, dr.id
)
`, userID, domainID, userID, domainID, domainID, userID, domainID)
}
`
func buildQuery(gm groups.PageMeta, ids ...string) string {
queries := []string{}
if len(ids) > 0 {
queries = append(queries, fmt.Sprintf(" id in ('%s') ", strings.Join(ids, "', '")))
queries = append(queries, "id = ANY(:ids)")
}
if gm.Name != "" {
queries = append(queries, "g.name ILIKE '%' || :name || '%'")
@@ -1233,6 +1289,9 @@ type dbGroup struct {
AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions"`
MemberID string `db:"member_id,omitempty"`
Roles json.RawMessage `db:"roles,omitempty"`
TotalCount uint64 `db:"total_count"`
UserID string `db:"user_id,omitempty"`
DomainIDParam string `db:"domain_id_param,omitempty"`
}
func toDBGroup(g groups.Group) (dbGroup, error) {
@@ -1360,31 +1419,35 @@ func toDBGroupPageMeta(pm groups.PageMeta) (dbGroupPageMeta, error) {
RoleID: pm.RoleID,
Actions: pm.Actions,
AccessType: pm.AccessType,
Path: pm.Path,
CreatedFrom: pm.CreatedFrom,
CreatedTo: pm.CreatedTo,
}, nil
}
type dbGroupPageMeta struct {
ID string `db:"id"`
Name string `db:"name"`
ParentID string `db:"parent_id"`
DomainID string `db:"domain_id"`
Metadata []byte `db:"metadata"`
Path string `db:"path"`
Level uint64 `db:"level"`
Total uint64 `db:"total"`
Limit uint64 `db:"limit"`
Offset uint64 `db:"offset"`
Subject string `db:"subject"`
RoleName string `db:"role_name"`
RoleID string `db:"role_id"`
Actions pq.StringArray `db:"actions"`
AccessType string `db:"access_type"`
Status groups.Status `db:"status"`
Tags pgtype.TextArray `db:"tags"`
CreatedFrom time.Time `db:"created_from"`
CreatedTo time.Time `db:"created_to"`
ID string `db:"id"`
Name string `db:"name"`
ParentID string `db:"parent_id"`
DomainID string `db:"domain_id"`
Metadata []byte `db:"metadata"`
Path string `db:"path"`
Level uint64 `db:"level"`
Total uint64 `db:"total"`
Limit uint64 `db:"limit"`
Offset uint64 `db:"offset"`
Subject string `db:"subject"`
RoleName string `db:"role_name"`
RoleID string `db:"role_id"`
Actions pq.StringArray `db:"actions"`
AccessType string `db:"access_type"`
Status groups.Status `db:"status"`
Tags pgtype.TextArray `db:"tags"`
IDs pq.StringArray `db:"ids"`
CreatedFrom time.Time `db:"created_from"`
CreatedTo time.Time `db:"created_to"`
UserID string `db:"user_id"`
DomainIDParam string `db:"domain_id_param"`
}
func (repo groupRepository) processRows(rows *sqlx.Rows) ([]groups.Group, error) {
@@ -1403,23 +1466,23 @@ func (repo groupRepository) processRows(rows *sqlx.Rows) ([]groups.Group, error)
return items, nil
}
func (repo groupRepository) getInsertQuery(c context.Context, g groups.Group) (string, error) {
func (repo groupRepository) getInsertQuery(c context.Context, g groups.Group) (string, string, error) {
switch {
case g.Parent != "":
parent, err := repo.RetrieveByID(c, g.Parent)
if err != nil {
return "", err
return "", "", err
}
path := parent.Path + "." + g.ID
if len(strings.Split(path, ".")) > groups.MaxPathLength {
return "", fmt.Errorf("reached max nested depth")
return "", "", fmt.Errorf("reached max nested depth")
}
return fmt.Sprintf(`INSERT INTO groups (name, description, tags, id, domain_id, parent_id, metadata, created_at, status, path)
VALUES (:name, :description, :tags, :id, :domain_id, :parent_id, :metadata, :created_at, :status, '%s')
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, path), nil
return `INSERT INTO groups (name, description, tags, id, domain_id, parent_id, metadata, created_at, status, path)
VALUES (:name, :description, :tags, :id, :domain_id, :parent_id, :metadata, :created_at, :status, CAST(:path AS ltree))
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, path, nil
default:
return `INSERT INTO groups (name, description, tags, id, domain_id, metadata, created_at, status, path)
VALUES (:name, :description, :tags, :id, :domain_id, :metadata, :created_at, :status, :id)
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, nil
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, "", nil
}
}
+9
View File
@@ -95,6 +95,15 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
`SELECT 1`,
},
},
{
Id: "groups_07",
Up: []string{
`CREATE INDEX IF NOT EXISTS idx_groups_domain_id_status ON groups(domain_id, status);`,
},
Down: []string{
`DROP INDEX IF EXISTS idx_groups_domain_id_status;`,
},
},
},
}
+9
View File
@@ -63,6 +63,15 @@ func Migration() *migrate.MemoryMigrationSource {
`ALTER TABLE clients_telemetry ALTER COLUMN last_seen TYPE TIMESTAMP;`,
},
},
{
Id: "journal_03",
Up: []string{
`CREATE INDEX IF NOT EXISTS idx_journal_domain ON journal(domain);`,
},
Down: []string{
`DROP INDEX IF EXISTS idx_journal_domain;`,
},
},
},
}
}
+12 -10
View File
@@ -71,7 +71,7 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
if page.Direction == "" {
page.Direction = "ASC"
}
q := fmt.Sprintf("SELECT %s FROM journal %s ORDER BY occurred_at %s LIMIT :limit OFFSET :offset;", sq, query, page.Direction)
q := fmt.Sprintf("SELECT %s, COUNT(*) OVER() AS total_count FROM journal %s ORDER BY occurred_at %s LIMIT :limit OFFSET :offset;", sq, query, page.Direction)
rows, err := repo.db.NamedQueryContext(ctx, q, page)
if err != nil {
@@ -79,12 +79,14 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
}
defer rows.Close()
var total uint64
var items []journal.Journal
for rows.Next() {
var item dbJournal
if err = rows.StructScan(&item); err != nil {
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
total = item.TotalCount
j, err := toJournal(item)
if err != nil {
return journal.JournalsPage{}, err
@@ -92,21 +94,20 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
items = append(items, j)
}
tq := fmt.Sprintf(`SELECT COUNT(*) FROM journal %s;`, query)
total, err := postgres.Total(ctx, repo.db, tq, page)
if err != nil {
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
if len(items) == 0 {
tq := fmt.Sprintf(`SELECT COUNT(*) FROM journal %s;`, query)
total, err = postgres.Total(ctx, repo.db, tq, page)
if err != nil {
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
}
journalsPage := journal.JournalsPage{
return journal.JournalsPage{
Total: total,
Offset: page.Offset,
Limit: page.Limit,
Journals: items,
}
return journalsPage, nil
}, nil
}
func pageQuery(pm journal.Page) string {
@@ -139,6 +140,7 @@ type dbJournal struct {
OccurredAt time.Time `db:"occurred_at"`
Attributes []byte `db:"attributes"`
Metadata []byte `db:"metadata"`
TotalCount uint64 `db:"total_count"`
}
func toDBJournal(j journal.Journal) (dbJournal, error) {
+11
View File
@@ -66,6 +66,17 @@ func Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName string)
fmt.Sprintf(`ALTER TABLE %s_roles ALTER COLUMN updated_at TYPE TIMESTAMP;`, rolesTableNamePrefix),
},
},
{
Id: fmt.Sprintf("%s_roles_3", rolesTableNamePrefix),
Up: []string{
fmt.Sprintf(`CREATE INDEX IF NOT EXISTS idx_%s_role_members_member_id ON %s_role_members(member_id);`, rolesTableNamePrefix, rolesTableNamePrefix),
fmt.Sprintf(`CREATE INDEX IF NOT EXISTS idx_%s_role_actions_action ON %s_role_actions(action text_pattern_ops);`, rolesTableNamePrefix, rolesTableNamePrefix),
},
Down: []string{
fmt.Sprintf(`DROP INDEX IF EXISTS idx_%s_role_members_member_id;`, rolesTableNamePrefix),
fmt.Sprintf(`DROP INDEX IF EXISTS idx_%s_role_actions_action;`, rolesTableNamePrefix),
},
},
},
}, nil
}
+5 -5
View File
@@ -694,7 +694,7 @@ func (repo *Repository) RoleRemoveAllMembers(ctx context.Context, role roles.Rol
dbrcap := dbRoleAction{RoleID: role.ID}
if _, err := repo.db.NamedExecContext(ctx, q, dbrcap); err != nil {
if _, err := tx.NamedExec(q, dbrcap); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
@@ -1113,7 +1113,7 @@ ungrouped_members AS (
ARRAY_AGG(DISTINCT cra."action") AS actions,
'direct' AS access_type,
'' AS access_provider_id,
''::::LTREE AS access_provider_path
CAST('' AS LTREE) AS access_provider_path
FROM
client_group cg
JOIN
@@ -1180,7 +1180,7 @@ ungrouped_members AS (
ARRAY_AGG(DISTINCT agg_dra."action") AS actions,
'domain' AS access_type,
d.id AS access_provider_id,
''::::LTREE AS access_provider_path
CAST('' AS LTREE) AS access_provider_path
FROM
client_group cg
JOIN
@@ -1251,7 +1251,7 @@ ungrouped_members AS (
ARRAY_AGG(DISTINCT cra."action") AS actions,
'direct' AS access_type,
'' AS access_provider_id,
''::::LTREE AS access_provider_path
CAST('' AS LTREE) AS access_provider_path
FROM
channel_group cg
JOIN
@@ -1318,7 +1318,7 @@ ungrouped_members AS (
ARRAY_AGG(DISTINCT agg_dra."action") AS actions,
'domain' AS access_type,
d.id AS access_provider_id,
''::::LTREE AS access_provider_path
CAST('' AS LTREE) AS access_provider_path
FROM
channel_group cg
JOIN
+76 -69
View File
@@ -18,6 +18,7 @@ import (
"github.com/absmach/supermq/pkg/postgres"
"github.com/absmach/supermq/users"
"github.com/jackc/pgtype"
"github.com/lib/pq"
)
type userRepo struct {
@@ -122,57 +123,64 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
}
squery := applyOrdering(query, pm)
q := fmt.Sprintf(`SELECT u.id, u.tags, u.email, u.metadata, u.status, u.role, u.first_name, u.last_name, u.username,
u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by, u.verified_at
FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
dbPage, err := ToDBUsersPage(pm)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
}
var items []users.User
if !pm.OnlyTotal {
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
if pm.OnlyTotal {
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrRetrieveAllUsers, err)
}
defer rows.Close()
for rows.Next() {
dbu := DBUser{}
if err := rows.StructScan(&dbu); err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := ToUser(dbu)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
}
items = append(items, c)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return users.UsersPage{
Page: users.Page{Total: total, Offset: pm.Offset, Limit: pm.Limit},
}, nil
}
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
squery := applyOrdering(query, pm)
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
q := fmt.Sprintf(`SELECT u.id, u.tags, u.email, u.metadata, u.status, u.role, u.first_name, u.last_name, u.username,
u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by, u.verified_at,
COUNT(*) OVER() AS total_count
FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrRetrieveAllUsers, err)
}
defer rows.Close()
var total uint64
var items []users.User
for rows.Next() {
dbu := DBUser{}
if err := rows.StructScan(&dbu); err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
total = dbu.TotalCount
c, err := ToUser(dbu)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
}
items = append(items, c)
}
page := users.UsersPage{
Page: users.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
if len(items) == 0 {
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
total, err = postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
}
return users.UsersPage{
Page: users.Page{Total: total, Offset: pm.Offset, Limit: pm.Limit},
Users: items,
}
return page, nil
}, nil
}
func (repo *userRepo) UpdateUsername(ctx context.Context, user users.User) (users.User, error) {
@@ -317,10 +325,10 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
}
tq := query
query = applyOrdering(query, pm)
squery := applyOrdering(query, pm)
q := fmt.Sprintf(`SELECT u.id, u.username, u.metadata, u.first_name, u.last_name, u.created_at, u.updated_at FROM users u %s LIMIT :limit OFFSET :offset;`, query)
q := fmt.Sprintf(`SELECT u.id, u.username, u.metadata, u.first_name, u.last_name, u.created_at, u.updated_at,
COUNT(*) OVER() AS total_count FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
dbPage, err := ToDBUsersPage(pm)
if err != nil {
@@ -333,12 +341,14 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
}
defer rows.Close()
var total uint64
var items []users.User
for rows.Next() {
dbu := DBUser{}
if err := rows.StructScan(&dbu); err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
total = dbu.TotalCount
c, err := ToUser(dbu)
if err != nil {
@@ -348,23 +358,18 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
items = append(items, c)
}
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, tq)
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
if len(items) == 0 {
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
total, err = postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
}
page := users.UsersPage{
return users.UsersPage{
Users: items,
Page: users.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
}
return page, nil
Page: users.Page{Total: total, Offset: pm.Offset, Limit: pm.Limit},
}, nil
}
func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (users.UsersPage, error) {
@@ -380,7 +385,8 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
squery := applyOrdering(query, pm)
q := fmt.Sprintf(`SELECT u.id, u.username, u.tags, u.email, u.metadata, u.status, u.role, u.first_name, u.last_name,
u.created_at, u.updated_at, COALESCE(u.updated_by, '') AS updated_by FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
u.created_at, u.updated_at, COALESCE(u.updated_by, '') AS updated_by,
COUNT(*) OVER() AS total_count FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
dbPage, err := ToDBUsersPage(pm)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
@@ -391,12 +397,14 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
}
defer rows.Close()
var total uint64
var items []users.User
for rows.Next() {
dbu := DBUser{}
if err := rows.StructScan(&dbu); err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
total = dbu.TotalCount
c, err := ToUser(dbu)
if err != nil {
@@ -405,23 +413,19 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
items = append(items, c)
}
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
if len(items) == 0 {
cq := fmt.Sprintf(`SELECT COUNT(*) FROM users u %s;`, query)
total, err = postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
if err != nil {
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
}
page := users.UsersPage{
return users.UsersPage{
Users: items,
Page: users.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
}
return page, nil
Page: users.Page{Total: total, Offset: pm.Offset, Limit: pm.Limit},
}, nil
}
func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users.User, error) {
@@ -498,6 +502,7 @@ type DBUser struct {
Email string `db:"email,omitempty"`
VerifiedAt sql.NullTime `db:"verified_at,omitempty"`
AuthProvider sql.NullString `db:"auth_provider,omitempty"`
TotalCount uint64 `db:"total_count"`
}
func toDBUser(u users.User) (DBUser, error) {
@@ -634,6 +639,7 @@ type DBUsersPage struct {
GroupID string `db:"group_id"`
Role users.Role `db:"role"`
Status users.Status `db:"status"`
IDs pq.StringArray `db:"ids"`
CreatedFrom time.Time `db:"created_from"`
CreatedTo time.Time `db:"created_to"`
}
@@ -662,6 +668,7 @@ func ToDBUsersPage(pm users.Page) (DBUsersPage, error) {
Status: pm.Status,
Tags: tags,
Role: pm.Role,
IDs: pq.StringArray(pm.IDs),
CreatedFrom: pm.CreatedFrom,
CreatedTo: pm.CreatedTo,
}, nil
@@ -699,7 +706,7 @@ func PageQuery(pm users.Page) (string, error) {
query = append(query, "metadata @> :metadata")
}
if len(pm.IDs) != 0 {
query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','")))
query = append(query, "id = ANY(:ids)")
}
if pm.Status != users.AllStatus {
query = append(query, "u.status = :status")