SMQ-2605 - Groups replication with groups events consumer & listing of things and channels (#2639)

Signed-off-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
Arvindh
2025-01-20 17:06:50 +05:30
committed by GitHub
parent a8b12e4461
commit 88d583bfb1
66 changed files with 2719 additions and 1651 deletions
+1
View File
@@ -58,6 +58,7 @@ const (
UserKey = "user"
DomainKey = "domain"
ChannelKey = "channel"
ConnTypeKey = "connection_type"
DefPermission = "read_permission"
DefTotal = uint64(100)
DefOffset = 0
+90 -42
View File
@@ -11,7 +11,7 @@ import (
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
smqclients "github.com/absmach/supermq/clients"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/errors"
"github.com/go-chi/chi/v5"
)
@@ -51,58 +51,106 @@ func decodeCreateChannelsReq(_ context.Context, r *http.Request) (interface{}, e
}
func decodeListChannels(_ context.Context, r *http.Request) (interface{}, error) {
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefClientStatus)
name, err := apiutil.ReadStringQuery(r, api.NameKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
m, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
n, err := apiutil.ReadStringQuery(r, api.NameKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
t, err := apiutil.ReadStringQuery(r, api.TagKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
p, err := apiutil.ReadStringQuery(r, api.PermissionKey, api.DefPermission)
tag, err := apiutil.ReadStringQuery(r, api.TagKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
lp, err := apiutil.ReadBoolQuery(r, api.ListPerms, api.DefListPerms)
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefGroupStatus)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
st, err := smqclients.ToStatus(s)
status, err := clients.ToStatus(s)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
meta, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
allActions, err := apiutil.ReadStringQuery(r, api.ActionsKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
actions := []string{}
allActions = strings.TrimSpace(allActions)
if allActions != "" {
actions = strings.Split(allActions, ",")
}
roleID, err := apiutil.ReadStringQuery(r, api.RoleIDKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
roleName, err := apiutil.ReadStringQuery(r, api.RoleNameKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
accessType, err := apiutil.ReadStringQuery(r, api.AccessTypeKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
userID, err := apiutil.ReadStringQuery(r, api.UserKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
groupID, err := apiutil.ReadStringQuery(r, api.GroupKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
clientID, err := apiutil.ReadStringQuery(r, api.ClientKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
req := listChannelsReq{
status: st,
offset: o,
limit: l,
metadata: m,
name: n,
tag: t,
permission: p,
listPerms: lp,
userID: chi.URLParam(r, "userID"),
id: id,
name: name,
tag: tag,
status: status,
metadata: meta,
roleName: roleName,
roleID: roleID,
actions: actions,
accessType: accessType,
order: order,
dir: dir,
offset: offset,
limit: limit,
groupID: groupID,
clientID: clientID,
userID: userID,
}
return req, nil
}
+15 -9
View File
@@ -104,15 +104,21 @@ func listChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
}
pm := channels.PageMetadata{
Status: req.status,
Offset: req.offset,
Limit: req.limit,
Name: req.name,
Tag: req.tag,
Permission: req.permission,
Metadata: req.metadata,
ListPerms: req.listPerms,
Id: req.id,
Offset: req.offset,
Limit: req.limit,
Name: req.name,
Order: req.order,
Dir: req.dir,
Metadata: req.metadata,
Tag: req.tag,
Status: req.status,
Group: req.groupID,
Client: req.clientID,
ConnectionType: req.connType,
RoleName: req.roleName,
RoleID: req.roleID,
Actions: req.actions,
AccessType: req.accessType,
}
page, err := svc.ListChannels(ctx, session, pm)
if err != nil {
+15 -15
View File
@@ -9,7 +9,7 @@ import (
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/channels"
smqclients "github.com/absmach/supermq/clients"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/connections"
)
@@ -64,29 +64,29 @@ func (req viewChannelReq) validate() error {
}
type listChannelsReq struct {
status smqclients.Status
offset uint64
limit uint64
name string
tag string
permission string
visibility string
status clients.Status
metadata clients.Metadata
roleName string
roleID string
actions []string
accessType string
order string
dir string
offset uint64
limit uint64
groupID string
clientID string
connType string
userID string
listPerms bool
metadata smqclients.Metadata
id string
}
func (req listChannelsReq) validate() error {
if req.limit > api.MaxLimitSize || req.limit < 1 {
return apiutil.ErrLimitSize
}
if req.visibility != "" &&
req.visibility != api.AllVisibility &&
req.visibility != api.MyVisibility &&
req.visibility != api.SharedVisibility {
return apiutil.ErrInvalidVisibilityType
}
if len(req.name) > api.MaxNameSize {
return apiutil.ErrNameSize
}
-8
View File
@@ -174,14 +174,6 @@ func TestListChannelsReqValidation(t *testing.T) {
},
err: apiutil.ErrNameSize,
},
{
desc: "invalid visibility",
req: listChannelsReq{
limit: 10,
visibility: "invalid",
},
err: apiutil.ErrInvalidVisibilityType,
},
}
for _, tc := range cases {
err := tc.req.validate()
+41 -27
View File
@@ -22,29 +22,43 @@ type Channel struct {
ParentGroup string `json:"parent_group_id,omitempty"`
Domain string `json:"domain_id,omitempty"`
Metadata clients.Metadata `json:"metadata,omitempty"`
CreatedBy string `json:"created_by,omitempty"`
CreatedAt time.Time `json:"created_at,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
UpdatedBy string `json:"updated_by,omitempty"`
Status clients.Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
Permissions []string `json:"permissions,omitempty"` // 1 for enabled, 0 for disabled
Status clients.Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
// Extended
ParentGroupPath string `json:"parent_group_path"`
RoleID string `json:"role_id"`
RoleName string `json:"role_name"`
Actions []string `json:"actions"`
AccessType string `json:"access_type"`
AccessProviderId string `json:"access_provider_id"`
AccessProviderRoleId string `json:"access_provider_role_id"`
AccessProviderRoleName string `json:"access_provider_role_name"`
AccessProviderRoleActions []string `json:"access_provider_role_actions"`
}
type PageMetadata struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Name string `json:"name,omitempty"`
Id string `json:"id,omitempty"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
Metadata clients.Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
Tag string `json:"tag,omitempty"`
Permission string `json:"permission,omitempty"`
Status clients.Status `json:"status,omitempty"`
IDs []string `json:"ids,omitempty"`
ListPerms bool `json:"-"`
ClientID string `json:"-"`
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Metadata clients.Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
Tag string `json:"tag,omitempty"`
Status clients.Status `json:"status,omitempty"`
Group string `json:"group,omitempty"`
Client string `json:"client,omitempty"`
ConnectionType string `json:"connection_type,omitempty"`
RoleName string `json:"role_name,omitempty"`
RoleID string `json:"role_id,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
IDs []string `json:"-"`
}
// ChannelsPage contains page related metadata as well as list of channels that
@@ -71,15 +85,15 @@ type AuthzReq struct {
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines"
type Service interface {
// CreateChannels adds channels to the user identified by the provided key.
// CreateChannels adds channels to the user.
CreateChannels(ctx context.Context, session authn.Session, channels ...Channel) ([]Channel, []roles.RoleProvision, error)
// ViewChannel retrieves data about the channel identified by the provided
// ID, that belongs to the user identified by the provided key.
// ID, that belongs to the user.
ViewChannel(ctx context.Context, session authn.Session, id string) (Channel, error)
// UpdateChannel updates the channel identified by the provided ID, that
// belongs to the user identified by the provided key.
// belongs to the user.
UpdateChannel(ctx context.Context, session authn.Session, channel Channel) (Channel, error)
// UpdateChannelTags updates the channel's tags.
@@ -89,17 +103,14 @@ type Service interface {
DisableChannel(ctx context.Context, session authn.Session, id string) (Channel, error)
// ListChannels retrieves data about subset of channels that belongs to the
// user identified by the provided key.
// ListChannels retrieves data about subset of channels that belongs to the user.
ListChannels(ctx context.Context, session authn.Session, pm PageMetadata) (Page, error)
// ListChannelsByClient retrieves data about subset of channels that have
// specified client connected or not connected to them and belong to the user identified by
// the provided key.
ListChannelsByClient(ctx context.Context, session authn.Session, id string, pm PageMetadata) (Page, error)
// ListUserChannels retrieves data about subset of channels that belong to the specified user.
ListUserChannels(ctx context.Context, session authn.Session, userID string, pm PageMetadata) (Page, error)
// RemoveChannel removes the client identified by the provided ID, that
// belongs to the user identified by the provided key.
// belongs to the user.
RemoveChannel(ctx context.Context, session authn.Session, id string) error
// Connect adds clients to the channels list of connected clients.
@@ -131,6 +142,9 @@ type Repository interface {
ChangeStatus(ctx context.Context, channel Channel) (Channel, error)
// RetrieveUserChannels retrieves the channel of given domainID and userID.
RetrieveUserChannels(ctx context.Context, domainID, userID string, pm PageMetadata) (Page, error)
// RetrieveByID retrieves the channel having the provided identifier
RetrieveByID(ctx context.Context, id string) (Channel, error)
+27 -30
View File
@@ -206,9 +206,6 @@ func (lce listChannelEvent) Encode() (map[string]interface{}, error) {
if lce.Tag != "" {
val["tag"] = lce.Tag
}
if lce.Permission != "" {
val["permission"] = lce.Permission
}
if lce.Status.String() != "" {
val["status"] = lce.Status.String()
}
@@ -219,48 +216,48 @@ func (lce listChannelEvent) Encode() (map[string]interface{}, error) {
return val, nil
}
type listChannelByClientEvent struct {
clientID string
type listUserChannelsEvent struct {
userID string
channels.PageMetadata
authn.Session
}
func (lcte listChannelByClientEvent) Encode() (map[string]interface{}, error) {
func (luce listUserChannelsEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": channelList,
"client_id": lcte.clientID,
"total": lcte.Total,
"offset": lcte.Offset,
"limit": lcte.Limit,
"domain": lcte.DomainID,
"user_id": lcte.UserID,
"token_type": lcte.Type.String(),
"super_admin": lcte.SuperAdmin,
"req_user_id": luce.userID,
"total": luce.Total,
"offset": luce.Offset,
"limit": luce.Limit,
"domain": luce.DomainID,
"user_id": luce.UserID,
"token_type": luce.Type.String(),
"super_admin": luce.SuperAdmin,
}
if lcte.Name != "" {
val["name"] = lcte.Name
if luce.Name != "" {
val["name"] = luce.Name
}
if lcte.Order != "" {
val["order"] = lcte.Order
if luce.Order != "" {
val["order"] = luce.Order
}
if lcte.Dir != "" {
val["dir"] = lcte.Dir
if luce.Dir != "" {
val["dir"] = luce.Dir
}
if lcte.Metadata != nil {
val["metadata"] = lcte.Metadata
if luce.Metadata != nil {
val["metadata"] = luce.Metadata
}
if lcte.Tag != "" {
val["tag"] = lcte.Tag
if luce.Domain != "" {
val["domain"] = luce.Domain
}
if lcte.Permission != "" {
val["permission"] = lcte.Permission
if luce.Tag != "" {
val["tag"] = luce.Tag
}
if lcte.Status.String() != "" {
val["status"] = lcte.Status.String()
if luce.Status.String() != "" {
val["status"] = luce.Status.String()
}
if len(lcte.IDs) > 0 {
val["ids"] = lcte.IDs
if len(luce.IDs) > 0 {
val["ids"] = luce.IDs
}
return val, nil
+4 -4
View File
@@ -126,13 +126,13 @@ func (es *eventStore) ListChannels(ctx context.Context, session authn.Session, p
return cp, nil
}
func (es *eventStore) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) {
cp, err := es.svc.ListChannelsByClient(ctx, session, clientID, pm)
func (es *eventStore) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
cp, err := es.svc.ListUserChannels(ctx, session, userID, pm)
if err != nil {
return cp, err
}
event := listChannelByClientEvent{
clientID: clientID,
event := listUserChannelsEvent{
userID: userID,
PageMetadata: pm,
Session: session,
}
+7 -3
View File
@@ -22,6 +22,7 @@ import (
var (
errView = errors.New("not authorized to view channel")
errList = errors.New("not authorized to list user channels")
errUpdate = errors.New("not authorized to update channel")
errUpdateTags = errors.New("not authorized to update channel tags")
errEnable = errors.New("not authorized to enable channel")
@@ -164,13 +165,13 @@ func (am *authorizationMiddleware) ListChannels(ctx context.Context, session aut
}
}
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
if err := am.checkSuperAdmin(ctx, session.UserID); err == nil {
session.SuperAdmin = true
}
return am.svc.ListChannels(ctx, session, pm)
}
func (am *authorizationMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) {
func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
@@ -184,7 +185,10 @@ func (am *authorizationMiddleware) ListChannelsByClient(ctx context.Context, ses
return channels.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
}
return am.svc.ListChannelsByClient(ctx, session, clientID, pm)
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
return channels.Page{}, errors.Wrap(err, errList)
}
return am.svc.ListUserChannels(ctx, session, userID, pm)
}
func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
+5 -5
View File
@@ -82,11 +82,11 @@ func (lm *loggingMiddleware) ListChannels(ctx context.Context, session authn.Ses
return lm.svc.ListChannels(ctx, session, pm)
}
func (lm *loggingMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (cp channels.Page, err error) {
func (lm *loggingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (cp channels.Page, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("client_id", clientID),
slog.String("user_id", userID),
slog.Group("page",
slog.Uint64("limit", pm.Limit),
slog.Uint64("offset", pm.Offset),
@@ -95,12 +95,12 @@ func (lm *loggingMiddleware) ListChannelsByClient(ctx context.Context, session a
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("List channels by client failed", args...)
lm.logger.Warn("List user channels failed", args...)
return
}
lm.logger.Info("List channels by client completed successfully", args...)
lm.logger.Info("List user channels completed successfully", args...)
}(time.Now())
return lm.svc.ListChannelsByClient(ctx, session, clientID, pm)
return lm.svc.ListUserChannels(ctx, session, userID, pm)
}
func (lm *loggingMiddleware) UpdateChannel(ctx context.Context, session authn.Session, client channels.Channel) (c channels.Channel, err error) {
+4 -4
View File
@@ -58,12 +58,12 @@ func (ms *metricsMiddleware) ListChannels(ctx context.Context, session authn.Ses
return ms.svc.ListChannels(ctx, session, pm)
}
func (ms *metricsMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) {
func (ms *metricsMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_channels_by_client").Add(1)
ms.latency.With("method", "list_channels_by_client").Observe(time.Since(begin).Seconds())
ms.counter.With("method", "list_user_channels").Add(1)
ms.latency.With("method", "list_user_channels").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ListChannelsByClient(ctx, session, clientID, pm)
return ms.svc.ListUserChannels(ctx, session, userID, pm)
}
func (ms *metricsMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
+28
View File
@@ -529,6 +529,34 @@ func (_m *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Ro
return r0, r1
}
// RetrieveUserChannels provides a mock function with given fields: ctx, domainID, userID, pm
func (_m *Repository) RetrieveUserChannels(ctx context.Context, domainID string, userID string, pm channels.PageMetadata) (channels.Page, error) {
ret := _m.Called(ctx, domainID, userID, pm)
if len(ret) == 0 {
panic("no return value specified for RetrieveUserChannels")
}
var r0 channels.Page
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.PageMetadata) (channels.Page, error)); ok {
return rf(ctx, domainID, userID, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.PageMetadata) channels.Page); ok {
r0 = rf(ctx, domainID, userID, pm)
} else {
r0 = ret.Get(0).(channels.Page)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, channels.PageMetadata) error); ok {
r1 = rf(ctx, domainID, userID, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RoleAddActions provides a mock function with given fields: ctx, role, actions
func (_m *Repository) RoleAddActions(ctx context.Context, role roles.Role, actions []string) ([]string, error) {
ret := _m.Called(ctx, role, actions)
+7 -7
View File
@@ -246,27 +246,27 @@ func (_m *Service) ListChannels(ctx context.Context, session authn.Session, pm c
return r0, r1
}
// ListChannelsByClient provides a mock function with given fields: ctx, session, id, pm
func (_m *Service) ListChannelsByClient(ctx context.Context, session authn.Session, id string, pm channels.PageMetadata) (channels.Page, error) {
ret := _m.Called(ctx, session, id, pm)
// ListUserChannels provides a mock function with given fields: ctx, session, userID, pm
func (_m *Service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
ret := _m.Called(ctx, session, userID, pm)
if len(ret) == 0 {
panic("no return value specified for ListChannelsByClient")
panic("no return value specified for ListUserChannels")
}
var r0 channels.Page
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.PageMetadata) (channels.Page, error)); ok {
return rf(ctx, session, id, pm)
return rf(ctx, session, userID, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.PageMetadata) channels.Page); ok {
r0 = rf(ctx, session, id, pm)
r0 = rf(ctx, session, userID, pm)
} else {
r0 = ret.Get(0).(channels.Page)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, channels.PageMetadata) error); ok {
r1 = rf(ctx, session, id, pm)
r1 = rf(ctx, session, userID, pm)
} else {
r1 = ret.Error(1)
}
+404 -50
View File
@@ -21,6 +21,7 @@ import (
"github.com/absmach/supermq/pkg/postgres"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
"github.com/jackc/pgtype"
"github.com/lib/pq"
)
const (
@@ -152,7 +153,7 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe
query = applyOrdering(query, pm)
q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.metadata, COALESCE(c.domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, c.status,
c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query)
c.created_by, c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c %s LIMIT :limit OFFSET :offset;`, query)
dbPage, err := toDBChannelsPage(pm)
if err != nil {
@@ -196,6 +197,303 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe
return page, nil
}
func (repo *channelRepository) RetrieveUserChannels(ctx context.Context, domainID, userID string, pm channels.PageMetadata) (channels.Page, error) {
return repo.retrieveClients(ctx, domainID, userID, pm)
}
func (repo *channelRepository) retrieveClients(ctx context.Context, domainID, userID string, pm channels.PageMetadata) (channels.Page, error) {
pageQuery, err := PageQuery(pm)
if err != nil {
return channels.Page{}, err
}
bq := repo.userChannelsBaseQuery(domainID, userID)
connJoinQuery := ""
if pm.Client != "" {
connJoinQuery = "JOIN connection conn ON conn.channel_id = c.id"
}
q := fmt.Sprintf(`
%s
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.tags,
c.metadata,
c.created_by,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
c.parent_group_path,
c.role_id,
c.role_name,
c.actions,
c.access_type,
c.access_provider_id,
c.access_provider_role_id,
c.access_provider_role_name,
c.access_provider_role_actions
FROM
final_channels c
%s
%s
`, bq, connJoinQuery, pageQuery)
q = applyOrdering(q, pm)
dbPage, err := toDBChannelsPage(pm)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer rows.Close()
var items []channels.Channel
for rows.Next() {
dbc := dbChannel{}
if err := rows.StructScan(&dbc); err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
c, err := toChannel(dbc)
if err != nil {
return channels.Page{}, err
}
items = append(items, c)
}
chJoinQuery := ""
if pm.Client != "" {
chJoinQuery = "JOIN connection conn ON conn.channel_id = c.id"
}
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM (
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.tags,
c.metadata,
c.created_by,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
c.parent_group_path,
c.role_id,
c.role_name,
c.actions,
c.access_type,
c.access_provider_id,
c.access_provider_role_id,
c.access_provider_role_name,
c.access_provider_role_actions
FROM
final_channels c
%s
%s
) AS subquery;
`, bq, chJoinQuery, pageQuery)
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
page := channels.Page{
Channels: items,
PageMetadata: channels.PageMetadata{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
}
return page, nil
}
func (repo *channelRepository) userChannelsBaseQuery(domainID, userID string) string {
return fmt.Sprintf(`
WITH direct_channels AS (
select
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.tags,
c.metadata,
c.created_by,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
text2ltree('') as parent_group_path,
cr.id AS role_id,
cr."name" AS role_name,
array_agg(cra."action") AS actions,
'direct' as access_type,
'' AS access_provider_id,
'' AS access_provider_role_id,
'' AS access_provider_role_name,
array[]::::text[] AS access_provider_role_actions
FROM
channels_role_members crm
JOIN
channels_role_actions cra ON cra.role_id = crm.role_id
JOIN
channels_roles cr ON cr.id = crm.role_id
JOIN
channels c ON c.id = cr.entity_id
WHERE
crm.member_id = '%s'
AND c.domain_id = '%s'
GROUP BY
cr.entity_id, crm.member_id, cr.id, cr."name", c.id
),
direct_groups AS (
SELECT
g.*,
gr.entity_id AS entity_id,
grm.member_id AS member_id,
gr.id AS role_id,
gr."name" AS role_name,
array_agg(gra."action") AS actions
FROM
groups_role_members grm
JOIN
groups_role_actions gra ON gra.role_id = grm.role_id
JOIN
groups_roles gr ON gr.id = grm.role_id
JOIN
"groups" g ON g.id = gr.entity_id
WHERE
grm.member_id = '%s'
AND g.domain_id = '%s'
GROUP BY
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
),
direct_groups_with_subgroup AS (
SELECT
*
FROM direct_groups
WHERE EXISTS (
SELECT 1
FROM unnest(direct_groups.actions) AS action
WHERE action LIKE 'subgroup_%%'
)
),
indirect_child_groups AS (
SELECT
DISTINCT indirect_child_groups.id as child_id,
indirect_child_groups.*,
dgws.id as access_provider_id,
dgws.role_id as access_provider_role_id,
dgws.role_name as access_provider_role_name,
dgws.actions as access_provider_role_actions
FROM
direct_groups_with_subgroup dgws
JOIN
groups indirect_child_groups ON indirect_child_groups.path <@ dgws.path
WHERE
indirect_child_groups.domain_id = '%s'
AND NOT EXISTS (
SELECT 1
FROM direct_groups_with_subgroup dgws
WHERE dgws.id = indirect_child_groups.id
)
),
final_groups AS (
SELECT
id,
parent_id,
domain_id,
"name",
description,
metadata,
created_at,
updated_at,
updated_by,
status,
"path",
role_id,
role_name,
actions,
'direct_group' AS access_type,
'' AS access_provider_id,
'' AS access_provider_role_id,
'' AS access_provider_role_name,
array[]::::text[] AS access_provider_role_actions
FROM
direct_groups
UNION
SELECT
id,
parent_id,
domain_id,
"name",
description,
metadata,
created_at,
updated_at,
updated_by,
status,
"path",
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
'indirect_group' AS access_type,
access_provider_id,
access_provider_role_id,
access_provider_role_name,
access_provider_role_actions
FROM
indirect_child_groups
),
final_channels AS (
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.tags,
c.metadata,
c.created_by,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
g.path AS parent_group_path,
g.role_id,
g.role_name,
g.actions,
g.access_type,
g.access_provider_id,
g.access_provider_role_id,
g.access_provider_role_name,
g.access_provider_role_actions
FROM
final_groups g
JOIN
channels c ON c.parent_group_id = g.id
WHERE
c.id NOT IN (SELECT id FROM direct_channels)
UNION
SELECT * FROM direct_channels
)
`, userID, domainID, userID, domainID, domainID)
}
func (cr *channelRepository) Remove(ctx context.Context, ids ...string) error {
q := "DELETE FROM channels AS c WHERE c.id = ANY(:channel_ids) ;"
params := map[string]interface{}{
@@ -361,7 +659,7 @@ func (cr *channelRepository) RemoveChannelConnections(ctx context.Context, chann
func (cr *channelRepository) RetrieveParentGroupChannels(ctx context.Context, parentGroupID string) ([]channels.Channel, error) {
query := `SELECT c.id, c.name, c.tags, c.metadata, COALESCE(c.domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, c.status,
c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c WHERE c.parent_group_id = :parent_group_id ;`
c.created_by, c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c WHERE c.parent_group_id = :parent_group_id ;`
rows, err := cr.db.NamedQueryContext(ctx, query, dbChannel{ParentGroup: toNullString(parentGroupID)})
if err != nil {
@@ -420,17 +718,26 @@ func (cr *channelRepository) update(ctx context.Context, ch channels.Channel, qu
}
type dbChannel struct {
ID string `db:"id"`
Name string `db:"name,omitempty"`
ParentGroup sql.NullString `db:"parent_group_id,omitempty"`
Tags pgtype.TextArray `db:"tags,omitempty"`
Domain string `db:"domain_id"`
Metadata []byte `db:"metadata,omitempty"`
CreatedAt time.Time `db:"created_at,omitempty"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
UpdatedBy *string `db:"updated_by,omitempty"`
Status clients.Status `db:"status,omitempty"`
Role *clients.Role `db:"role,omitempty"`
ID string `db:"id"`
Name string `db:"name,omitempty"`
ParentGroup sql.NullString `db:"parent_group_id,omitempty"`
Tags pgtype.TextArray `db:"tags,omitempty"`
Domain string `db:"domain_id"`
Metadata []byte `db:"metadata,omitempty"`
CreatedBy *string `db:"created_by,omitempty"`
CreatedAt time.Time `db:"created_at,omitempty"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
UpdatedBy *string `db:"updated_by,omitempty"`
Status clients.Status `db:"status,omitempty"`
ParentGroupPath string `db:"parent_group_path,omitempty"`
RoleID string `db:"role_id,omitempty"`
RoleName string `db:"role_name,omitempty"`
Actions pq.StringArray `db:"actions,omitempty"`
AccessType string `db:"access_type,omitempty"`
AccessProviderId string `db:"access_provider_id,omitempty"`
AccessProviderRoleId string `db:"access_provider_role_id,omitempty"`
AccessProviderRoleName string `db:"access_provider_role_name,omitempty"`
AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions,omitempty"`
}
func toDBChannel(ch channels.Channel) (dbChannel, error) {
@@ -446,6 +753,10 @@ func toDBChannel(ch channels.Channel) (dbChannel, error) {
if err := tags.Set(ch.Tags); err != nil {
return dbChannel{}, err
}
var createdBy *string
if ch.CreatedBy != "" {
createdBy = &ch.CreatedBy
}
var updatedBy *string
if ch.UpdatedBy != "" {
updatedBy = &ch.UpdatedBy
@@ -461,6 +772,7 @@ func toDBChannel(ch channels.Channel) (dbChannel, error) {
Domain: ch.Domain,
Tags: tags,
Metadata: data,
CreatedBy: createdBy,
CreatedAt: ch.CreatedAt,
UpdatedAt: updatedAt,
UpdatedBy: updatedBy,
@@ -497,6 +809,10 @@ func toChannel(ch dbChannel) (channels.Channel, error) {
for _, e := range ch.Tags.Elements {
tags = append(tags, e.String)
}
var createdBy string
if ch.CreatedBy != nil {
createdBy = *ch.CreatedBy
}
var updatedBy string
if ch.UpdatedBy != nil {
updatedBy = *ch.UpdatedBy
@@ -507,16 +823,26 @@ func toChannel(ch dbChannel) (channels.Channel, error) {
}
newCh := channels.Channel{
ID: ch.ID,
Name: ch.Name,
Tags: tags,
Domain: ch.Domain,
ParentGroup: toString(ch.ParentGroup),
Metadata: metadata,
CreatedAt: ch.CreatedAt,
UpdatedAt: updatedAt,
UpdatedBy: updatedBy,
Status: ch.Status,
ID: ch.ID,
Name: ch.Name,
Tags: tags,
Domain: ch.Domain,
ParentGroup: toString(ch.ParentGroup),
Metadata: metadata,
CreatedBy: createdBy,
CreatedAt: ch.CreatedAt,
UpdatedAt: updatedAt,
UpdatedBy: updatedBy,
Status: ch.Status,
ParentGroupPath: ch.ParentGroupPath,
RoleID: ch.RoleID,
RoleName: ch.RoleName,
Actions: ch.Actions,
AccessType: ch.AccessType,
AccessProviderId: ch.AccessProviderId,
AccessProviderRoleId: ch.AccessProviderRoleId,
AccessProviderRoleName: ch.AccessProviderRoleName,
AccessProviderRoleActions: ch.AccessProviderRoleActions,
}
return newCh, nil
@@ -533,9 +859,6 @@ func PageQuery(pm channels.PageMetadata) (string, error) {
query = append(query, "c.name ILIKE '%' || :name || '%'")
}
if pm.ClientID != "" {
query = append(query, "conn.client_id = :client_id")
}
if pm.Id != "" {
query = append(query, "c.id ILIKE '%' || :id || '%'")
}
@@ -543,12 +866,6 @@ func PageQuery(pm channels.PageMetadata) (string, error) {
query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')")
}
// If there are search params presents, use search and ignore other options.
// Always combine role with search params, so len(query) > 1.
if len(query) > 1 {
return fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")), nil
}
if mq != "" {
query = append(query, mq)
}
@@ -562,6 +879,31 @@ func PageQuery(pm channels.PageMetadata) (string, error) {
if pm.Domain != "" {
query = append(query, "c.domain_id = :domain_id")
}
if pm.Group != "" {
query = append(query, "c.parent_group_path @> (SELECT path from groups where id = :group_id) ")
}
if pm.Client != "" {
query = append(query, "conn.client_id = :client_id ")
if pm.ConnectionType != "" {
query = append(query, "conn.type = :conn_type ")
}
}
if pm.AccessType != "" {
query = append(query, "c.access_type = :access_type")
}
if pm.RoleID != "" {
query = append(query, "c.role_id = :role_id")
}
if pm.RoleName != "" {
query = append(query, "c.role_name = :role_name")
}
if len(pm.Actions) != 0 {
query = append(query, "c.actions @> :actions")
}
if len(pm.Metadata) > 0 {
query = append(query, "c.metadata @> :metadata")
}
var emq string
if len(query) > 0 {
emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND "))
@@ -586,28 +928,40 @@ func toDBChannelsPage(pm channels.PageMetadata) (dbChannelsPage, error) {
return dbChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
return dbChannelsPage{
Name: pm.Name,
Id: pm.Id,
Metadata: data,
Domain: pm.Domain,
Total: pm.Total,
Offset: pm.Offset,
Limit: pm.Limit,
Status: pm.Status,
Tag: pm.Tag,
Limit: pm.Limit,
Offset: pm.Offset,
Name: pm.Name,
Id: pm.Id,
Domain: pm.Domain,
Metadata: data,
Tag: pm.Tag,
Status: pm.Status,
GroupID: pm.Group,
ClientID: pm.Client,
ConnType: pm.ConnectionType,
RoleName: pm.RoleName,
RoleID: pm.RoleID,
Actions: pm.Actions,
AccessType: pm.AccessType,
}, nil
}
type dbChannelsPage struct {
Total uint64 `db:"total"`
Limit uint64 `db:"limit"`
Offset uint64 `db:"offset"`
Name string `db:"name"`
Id string `db:"id"`
Domain string `db:"domain_id"`
Metadata []byte `db:"metadata"`
Tag string `db:"tag"`
Status clients.Status `db:"status"`
Limit uint64 `db:"limit"`
Offset uint64 `db:"offset"`
Name string `db:"name"`
Id string `db:"id"`
Domain string `db:"domain_id"`
Metadata []byte `db:"metadata"`
Tag string `db:"tag"`
Status clients.Status `db:"status"`
GroupID string `db:"group_id"`
ClientID string `db:"client_id"`
ConnType string `db:"type"`
RoleName string `db:"role_name"`
RoleID string `db:"role_id"`
Actions pq.StringArray `db:"actions"`
AccessType string `db:"access_type"`
}
type dbConnection struct {
+9
View File
@@ -4,6 +4,7 @@
package postgres
import (
gpostgres "github.com/absmach/supermq/groups/postgres"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
@@ -56,5 +57,13 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
},
}
channelsMigration.Migrations = append(channelsMigration.Migrations, rolesMigration.Migrations...)
groupsMigration, err := gpostgres.Migration()
if err != nil {
return &migrate.MemoryMigrationSource{}, err
}
channelsMigration.Migrations = append(channelsMigration.Migrations, groupsMigration.Migrations...)
return channelsMigration, nil
}
+1 -1
View File
@@ -141,7 +141,7 @@ const (
// External Permission
// Domains.
domainCreateChannelPermission = "channel_create_permission"
domainListChanelPermission = "list_channels_permission"
domainListChanelPermission = "channel_read_permission"
// Groups.
groupSetChildChannelPermission = "channel_create_permission"
groupRemoveChildChannelPermission = "channel_create_permission"
+12 -68
View File
@@ -21,7 +21,6 @@ import (
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
"golang.org/x/sync/errgroup"
)
var (
@@ -183,50 +182,30 @@ func (svc service) ViewChannel(ctx context.Context, session authn.Session, id st
}
func (svc service) ListChannels(ctx context.Context, session authn.Session, pm PageMetadata) (Page, error) {
var ids []string
var err error
switch session.SuperAdmin {
case true:
pm.Domain = session.DomainID
default:
ids, err = svc.listChannelIDs(ctx, session.DomainUserID, pm.Permission)
cp, err := svc.repo.RetrieveAll(ctx, pm)
if err != nil {
return Page{}, errors.Wrap(svcerr.ErrNotFound, err)
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return cp, nil
default:
cp, err := svc.repo.RetrieveUserChannels(ctx, session.DomainID, session.UserID, pm)
if err != nil {
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return cp, nil
}
if len(ids) == 0 && pm.Domain == "" {
return Page{}, nil
}
pm.IDs = ids
}
cp, err := svc.repo.RetrieveAll(ctx, pm)
func (svc service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm PageMetadata) (Page, error) {
cp, err := svc.repo.RetrieveUserChannels(ctx, session.DomainID, userID, pm)
if err != nil {
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if pm.ListPerms && len(cp.Channels) > 0 {
g, ctx := errgroup.WithContext(ctx)
for i := range cp.Channels {
// Copying loop variable "i" to avoid "loop variable captured by func literal"
iter := i
g.Go(func() error {
return svc.retrievePermissions(ctx, session.DomainUserID, &cp.Channels[iter])
})
}
if err := g.Wait(); err != nil {
return Page{}, err
}
}
return cp, nil
}
func (svc service) ListChannelsByClient(ctx context.Context, session authn.Session, clID string, pm PageMetadata) (Page, error) {
return Page{}, nil
}
func (svc service) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
ok, err := svc.repo.DoesChannelHaveConnections(ctx, id)
if err != nil {
@@ -493,41 +472,6 @@ func (svc service) RemoveParentGroup(ctx context.Context, session authn.Session,
return nil
}
func (svc service) listChannelIDs(ctx context.Context, userID, permission string) ([]string, error) {
tids, err := svc.policy.ListAllObjects(ctx, policies.Policy{
SubjectType: policies.UserType,
Subject: userID,
Permission: permission,
ObjectType: policies.ChannelType,
})
if err != nil {
return nil, errors.Wrap(svcerr.ErrNotFound, err)
}
return tids.Policies, nil
}
func (svc service) retrievePermissions(ctx context.Context, userID string, channel *Channel) error {
permissions, err := svc.listUserClientPermission(ctx, userID, channel.ID)
if err != nil {
return err
}
channel.Permissions = permissions
return nil
}
func (svc service) listUserClientPermission(ctx context.Context, userID, clientID string) ([]string, error) {
lp, err := svc.policy.ListPermissions(ctx, policies.Policy{
SubjectType: policies.UserType,
Subject: userID,
Object: clientID,
ObjectType: policies.ChannelType,
}, []string{})
if err != nil {
return []string{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
return lp, nil
}
func (svc service) changeChannelStatus(ctx context.Context, userID string, channel Channel) (Channel, error) {
dbchannel, err := svc.repo.RetrieveByID(ctx, channel.ID)
if err != nil {
+148 -187
View File
@@ -21,6 +21,7 @@ import (
gpmocks "github.com/absmach/supermq/groups/mocks"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/authn"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
@@ -459,220 +460,180 @@ func TestDisableChannel(t *testing.T) {
func TestListChannels(t *testing.T) {
svc := newService(t)
channelWithPerms := validChannel
channelWithPerms.Permissions = []string{policysvc.AdminPermission, policysvc.EditPermission, policysvc.ViewPermission}
adminID := testsutil.GenerateUUID(t)
domainID := testsutil.GenerateUUID(t)
nonAdminID := testsutil.GenerateUUID(t)
cases := []struct {
desc string
session authn.Session
pageMeta channels.PageMetadata
listAllObjectsRes policysvc.PolicyPage
listAllObjectsErr error
retrieveAllRes channels.Page
retrieveAllErr error
listPermissionsRes policysvc.Permissions
listPermissionsErr error
resp channels.Page
err error
desc string
userKind string
session smqauthn.Session
page channels.PageMetadata
retrieveAllResponse channels.Page
response channels.Page
id string
size uint64
listObjectsErr error
retrieveAllErr error
listPermissionsErr error
err error
}{
{
desc: "list channesls as admin successfully",
session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true},
pageMeta: channels.PageMetadata{
Domain: validID,
desc: "list all channels successfully as non admin",
userKind: "non-admin",
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: channels.PageMetadata{
Offset: 0,
Limit: 100,
},
retrieveAllRes: channels.Page{
Channels: []channels.Channel{validChannel},
retrieveAllResponse: channels.Page{
PageMetadata: channels.PageMetadata{
Total: 1,
Total: 2,
Offset: 0,
Limit: 100,
},
Channels: []channels.Channel{validChannel, validChannel},
},
resp: channels.Page{
Channels: []channels.Channel{validChannel},
response: channels.Page{
PageMetadata: channels.PageMetadata{
Total: 1,
Total: 2,
Offset: 0,
Limit: 100,
},
Channels: []channels.Channel{validChannel, validChannel},
},
err: nil,
},
{
desc: "list channels as admin with list perms successfully",
session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true},
pageMeta: channels.PageMetadata{
Domain: validID,
ListPerms: true,
desc: "list all channels as non admin with failed to retrieve all",
userKind: "non-admin",
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: channels.PageMetadata{
Offset: 0,
Limit: 100,
},
listPermissionsRes: policysvc.Permissions{
policysvc.AdminPermission, policysvc.EditPermission, policysvc.ViewPermission,
},
retrieveAllRes: channels.Page{
Channels: []channels.Channel{validChannel},
PageMetadata: channels.PageMetadata{
Total: 1,
},
},
resp: channels.Page{
Channels: []channels.Channel{channelWithPerms},
PageMetadata: channels.PageMetadata{
Total: 1,
},
},
err: nil,
retrieveAllResponse: channels.Page{},
response: channels.Page{},
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list channels as admin with failed to retrieve all",
session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true},
pageMeta: channels.PageMetadata{
Domain: validID,
desc: "list all channels as non admin with failed super admin",
userKind: "non-admin",
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: channels.PageMetadata{
Offset: 0,
Limit: 100,
},
retrieveAllRes: channels.Page{},
retrieveAllErr: repoerr.ErrNotFound,
err: repoerr.ErrNotFound,
},
{
desc: "list channels as admin with failed to list permissions",
session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true},
pageMeta: channels.PageMetadata{
Domain: validID,
ListPerms: true,
},
retrieveAllRes: channels.Page{
Channels: []channels.Channel{validChannel},
PageMetadata: channels.PageMetadata{
Total: 1,
},
},
listPermissionsRes: policysvc.Permissions{},
listPermissionsErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "list channels as admin with no domain id",
session: authn.Session{UserID: validID, SuperAdmin: true},
pageMeta: channels.PageMetadata{},
response: channels.Page{},
err: nil,
},
{
desc: "list channels as user successfully",
session: validSession,
pageMeta: channels.PageMetadata{
Permission: policysvc.ViewPermission,
IDs: []string{validChannel.ID},
desc: "list all channels as non admin with failed to list objects",
userKind: "non-admin",
id: nonAdminID,
page: channels.PageMetadata{
Offset: 0,
Limit: 100,
},
listAllObjectsRes: policysvc.PolicyPage{
Policies: []string{validChannel.ID},
},
retrieveAllRes: channels.Page{
Channels: []channels.Channel{validChannel},
PageMetadata: channels.PageMetadata{
Total: 1,
},
},
resp: channels.Page{
Channels: []channels.Channel{validChannel},
PageMetadata: channels.PageMetadata{
Total: 1,
},
},
err: nil,
},
{
desc: "list channels as user with failed to list all objects",
session: validSession,
pageMeta: channels.PageMetadata{
Permission: policysvc.ViewPermission,
IDs: []string{validChannel.ID},
},
listAllObjectsErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "list channels as user with list permissions successfully",
session: validSession,
pageMeta: channels.PageMetadata{
Permission: policysvc.ViewPermission,
IDs: []string{validChannel.ID},
ListPerms: true,
},
listAllObjectsRes: policysvc.PolicyPage{
Policies: []string{validChannel.ID},
},
retrieveAllRes: channels.Page{
Channels: []channels.Channel{validChannel},
PageMetadata: channels.PageMetadata{
Total: 1,
},
},
listPermissionsRes: policysvc.Permissions{
policysvc.AdminPermission, policysvc.EditPermission, policysvc.ViewPermission,
},
resp: channels.Page{
Channels: []channels.Channel{channelWithPerms},
PageMetadata: channels.PageMetadata{
Total: 1,
},
},
err: nil,
},
{
desc: "list channels as user with list permissions and failed to list permissions",
session: validSession,
pageMeta: channels.PageMetadata{
Permission: policysvc.ViewPermission,
IDs: []string{validChannel.ID},
ListPerms: true,
},
listAllObjectsRes: policysvc.PolicyPage{
Policies: []string{validChannel.ID},
},
retrieveAllRes: channels.Page{
Channels: []channels.Channel{validChannel},
PageMetadata: channels.PageMetadata{
Total: 1,
},
},
listPermissionsRes: policysvc.Permissions{},
listPermissionsErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "list channels as user with failed to retrieve all",
session: validSession,
pageMeta: channels.PageMetadata{
Permission: policysvc.ViewPermission,
IDs: []string{validChannel.ID},
},
listAllObjectsRes: policysvc.PolicyPage{
Policies: []string{validChannel.ID},
},
retrieveAllRes: channels.Page{},
retrieveAllErr: repoerr.ErrNotFound,
err: repoerr.ErrNotFound,
response: channels.Page{},
listObjectsErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
policyCall := policies.On("ListAllObjects", context.Background(), policysvc.Policy{
SubjectType: policysvc.UserType,
Subject: validID,
Permission: policysvc.ViewPermission,
ObjectType: policysvc.ChannelType,
}).Return(tc.listAllObjectsRes, tc.listAllObjectsErr)
repoCall := repo.On("RetrieveAll", context.Background(), tc.pageMeta).Return(tc.retrieveAllRes, tc.retrieveAllErr)
policyCall1 := policies.On("ListPermissions", mock.Anything, policysvc.Policy{
SubjectType: policysvc.UserType,
Subject: validID,
Object: validChannel.ID,
ObjectType: policysvc.ChannelType,
}, []string{}).Return(tc.listPermissionsRes, tc.listPermissionsErr)
got, err := svc.ListChannels(context.Background(), tc.session, tc.pageMeta)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
assert.Equal(t, tc.resp, got)
policyCall.Unset()
repoCall.Unset()
policyCall1.Unset()
})
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
retrieveUserClientsCall := repo.On("RetrieveUserChannels", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
page, err := svc.ListChannels(context.Background(), tc.session, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
retrieveAllCall.Unset()
retrieveUserClientsCall.Unset()
}
cases2 := []struct {
desc string
userKind string
session smqauthn.Session
page channels.PageMetadata
retrieveAllResponse channels.Page
response channels.Page
id string
size uint64
listObjectsErr error
retrieveAllErr error
listPermissionsErr error
err error
}{
{
desc: "list all clients as admin successfully",
userKind: "admin",
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: channels.PageMetadata{
Offset: 0,
Limit: 100,
Domain: domainID,
},
retrieveAllResponse: channels.Page{
PageMetadata: channels.PageMetadata{
Total: 2,
Offset: 0,
Limit: 100,
},
Channels: []channels.Channel{validChannel, validChannel},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
Total: 2,
Offset: 0,
Limit: 100,
},
Channels: []channels.Channel{validChannel, validChannel},
},
err: nil,
},
{
desc: "list all clients as admin with failed to retrieve all",
userKind: "admin",
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: channels.PageMetadata{
Offset: 0,
Limit: 100,
Domain: domainID,
},
retrieveAllResponse: channels.Page{},
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list all clients as admin with failed to list clients",
userKind: "admin",
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: channels.PageMetadata{
Offset: 0,
Limit: 100,
Domain: domainID,
},
retrieveAllResponse: channels.Page{},
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases2 {
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
page, err := svc.ListChannels(context.Background(), tc.session, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
retrieveAllCall.Unset()
}
}
+3 -3
View File
@@ -50,10 +50,10 @@ func (tm *tracingMiddleware) ListChannels(ctx context.Context, session authn.Ses
return tm.svc.ListChannels(ctx, session, pm)
}
func (tm *tracingMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) {
ctx, span := tm.tracer.Start(ctx, "svc_list_channels")
func (tm *tracingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
ctx, span := tm.tracer.Start(ctx, "svc_list_user_channels")
defer span.End()
return tm.svc.ListChannelsByClient(ctx, session, clientID, pm)
return tm.svc.ListUserChannels(ctx, session, userID, pm)
}
// UpdateChannel traces the "UpdateChannel" operation of the wrapped policies.Service.
-26
View File
@@ -370,32 +370,6 @@ var cmdUsers = []cobra.Command{
logOKCmd(*cmd)
},
},
{
Use: "clients <user_id> <domain_id> <user_auth_token>",
Short: "List clients",
Long: "List clients of user\n" +
"Usage:\n" +
"\tsupermq-cli users clients <user_id> <user_auth_token>\n",
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 3 {
logUsageCmd(*cmd, cmd.Use)
return
}
pm := smqsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
}
tp, err := sdk.ListUserClients(args[0], args[1], pm, args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
}
logJSONCmd(*cmd, tp)
},
},
{
Use: "search <query> <user_auth_token>",
Short: "Search users",
+95 -41
View File
@@ -27,58 +27,112 @@ func decodeViewClient(_ context.Context, r *http.Request) (interface{}, error) {
}
func decodeListClients(_ context.Context, r *http.Request) (interface{}, error) {
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefClientStatus)
name, err := apiutil.ReadStringQuery(r, api.NameKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
m, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
n, err := apiutil.ReadStringQuery(r, api.NameKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
t, err := apiutil.ReadStringQuery(r, api.TagKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
p, err := apiutil.ReadStringQuery(r, api.PermissionKey, api.DefPermission)
tag, err := apiutil.ReadStringQuery(r, api.TagKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
lp, err := apiutil.ReadBoolQuery(r, api.ListPerms, api.DefListPerms)
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefGroupStatus)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
st, err := clients.ToStatus(s)
status, err := clients.ToStatus(s)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
meta, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir)
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder)
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
allActions, err := apiutil.ReadStringQuery(r, api.ActionsKey, "")
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
actions := []string{}
allActions = strings.TrimSpace(allActions)
if allActions != "" {
actions = strings.Split(allActions, ",")
}
roleID, err := apiutil.ReadStringQuery(r, api.RoleIDKey, "")
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
roleName, err := apiutil.ReadStringQuery(r, api.RoleNameKey, "")
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
accessType, err := apiutil.ReadStringQuery(r, api.AccessTypeKey, "")
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
userID, err := apiutil.ReadStringQuery(r, api.UserKey, "")
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
groupID, err := apiutil.ReadStringQuery(r, api.GroupKey, "")
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
channelID, err := apiutil.ReadStringQuery(r, api.ChannelKey, "")
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
connType, err := apiutil.ReadStringQuery(r, api.ConnTypeKey, "")
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
req := listClientsReq{
status: st,
offset: o,
limit: l,
metadata: m,
name: n,
tag: t,
permission: p,
listPerms: lp,
userID: chi.URLParam(r, "userID"),
id: id,
name: name,
tag: tag,
status: status,
metadata: meta,
roleName: roleName,
roleID: roleID,
actions: actions,
accessType: accessType,
order: order,
dir: dir,
offset: offset,
limit: limit,
groupID: groupID,
channelID: channelID,
connType: connType,
userID: userID,
}
return req, nil
}
+25 -11
View File
@@ -104,19 +104,33 @@ func listClientsEndpoint(svc clients.Service) endpoint.Endpoint {
}
pm := clients.Page{
Status: req.status,
Offset: req.offset,
Limit: req.limit,
Name: req.name,
Tag: req.tag,
Permission: req.permission,
Metadata: req.metadata,
ListPerms: req.listPerms,
Id: req.id,
Name: req.name,
Tag: req.tag,
Status: req.status,
Metadata: req.metadata,
RoleName: req.roleName,
RoleID: req.roleID,
Actions: req.actions,
AccessType: req.accessType,
Order: req.order,
Dir: req.dir,
Offset: req.offset,
Limit: req.limit,
Group: req.groupID,
Channel: req.channelID,
ConnectionType: req.connType,
}
var page clients.ClientsPage
var err error
switch req.userID != "" {
case true:
page, err = svc.ListUserClients(ctx, session, req.userID, pm)
default:
page, err = svc.ListClients(ctx, session, pm)
}
page, err := svc.ListClients(ctx, session, req.userID, pm)
if err != nil {
return nil, err
return clientsPageRes{}, err
}
res := clientsPageRes{
+1 -1
View File
@@ -743,7 +743,7 @@ func TestListClients(t *testing.T) {
}
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("ListClients", mock.Anything, tc.authnRes, "", mock.Anything).Return(tc.listClientsResponse, tc.err)
svcCall := svc.On("ListClients", mock.Anything, tc.authnRes, mock.Anything).Return(tc.listClientsResponse, tc.err)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
+14 -14
View File
@@ -71,29 +71,29 @@ func (req viewClientPermsReq) validate() error {
}
type listClientsReq struct {
status clients.Status
offset uint64
limit uint64
name string
tag string
permission string
visibility string
userID string
listPerms bool
status clients.Status
metadata clients.Metadata
id string
roleName string
roleID string
actions []string
accessType string
order string
dir string
offset uint64
limit uint64
groupID string
channelID string
connType string
userID string
}
func (req listClientsReq) validate() error {
if req.limit > api.MaxLimitSize || req.limit < 1 {
return apiutil.ErrLimitSize
}
if req.visibility != "" &&
req.visibility != api.AllVisibility &&
req.visibility != api.MyVisibility &&
req.visibility != api.SharedVisibility {
return apiutil.ErrInvalidVisibilityType
}
if len(req.name) > api.MaxNameSize {
return apiutil.ErrNameSize
}
-8
View File
@@ -210,14 +210,6 @@ func TestListClientsReqValidate(t *testing.T) {
},
err: apiutil.ErrLimitSize,
},
{
desc: "invalid visibility",
req: listClientsReq{
limit: 10,
visibility: "invalid",
},
err: apiutil.ErrInvalidVisibilityType,
},
{
desc: "name too long",
req: listClientsReq{
+44 -22
View File
@@ -13,6 +13,10 @@ import (
"github.com/absmach/supermq/pkg/roles"
)
type CtxKey int
const ListDomainClients CtxKey = iota
type Connection struct {
ClientID string
ChannelID string
@@ -35,11 +39,14 @@ type Repository interface {
// RetrieveAll retrieves all clients.
RetrieveAll(ctx context.Context, pm Page) (ClientsPage, error)
// RetrieveUserClients retrieve all clients of a given user id.
RetrieveUserClients(ctx context.Context, domainID, userID string, pm Page) (ClientsPage, error)
// SearchClients retrieves clients based on search criteria.
SearchClients(ctx context.Context, pm Page) (ClientsPage, error)
// RetrieveAllByIDs retrieves for given client IDs .
RetrieveAllByIDs(ctx context.Context, pm Page) (ClientsPage, error)
// RetrieveByIds
RetrieveByIds(ctx context.Context, ids []string) (ClientsPage, error)
// Update updates the client name and metadata.
Update(ctx context.Context, client Client) (Client, error)
@@ -66,8 +73,6 @@ type Repository interface {
// RetrieveBySecret retrieves a client based on the secret (key).
RetrieveBySecret(ctx context.Context, key string) (Client, error)
RetrieveByIds(ctx context.Context, ids []string) (ClientsPage, error)
AddConnections(ctx context.Context, conns []Connection) error
RemoveConnections(ctx context.Context, conns []Connection) error
@@ -105,8 +110,11 @@ type Service interface {
// View retrieves client info for a given client ID and an authorized token.
View(ctx context.Context, session authn.Session, id string) (Client, error)
// ListClients retrieves clients list for a valid auth token.
ListClients(ctx context.Context, session authn.Session, reqUserID string, pm Page) (ClientsPage, error)
// ListClients retrieves clients list for given page query.
ListClients(ctx context.Context, session authn.Session, pm Page) (ClientsPage, error)
// ListUserClients retrieves clients list for a given user id and page query.
ListUserClients(ctx context.Context, session authn.Session, userID string, pm Page) (ClientsPage, error)
// Update updates the client's name and metadata.
Update(ctx context.Context, session authn.Session, client Client) (Client, error)
@@ -161,8 +169,17 @@ type Client struct {
UpdatedAt time.Time `json:"updated_at,omitempty"`
UpdatedBy string `json:"updated_by,omitempty"`
Status Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
Permissions []string `json:"permissions,omitempty"`
Identity string `json:"identity,omitempty"`
// Extended
ParentGroupPath string `json:"parent_group_path,omitempty"`
RoleID string `json:"role_id,omitempty"`
RoleName string `json:"role_name,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
AccessProviderId string `json:"access_provider_id,omitempty"`
AccessProviderRoleId string `json:"access_provider_role_id,omitempty"`
AccessProviderRoleName string `json:"access_provider_role_name,omitempty"`
AccessProviderRoleActions []string `json:"access_provider_role_actions,omitempty"`
}
// ClientsPage contains page related metadata as well as list.
@@ -182,21 +199,26 @@ type MembersPage struct {
// Page contains the page metadata that helps navigation.
type Page struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Name string `json:"name,omitempty"`
Id string `json:"id,omitempty"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
Tag string `json:"tag,omitempty"`
Permission string `json:"permission,omitempty"`
Status Status `json:"status,omitempty"`
IDs []string `json:"ids,omitempty"`
Identity string `json:"identity,omitempty"`
ListPerms bool `json:"-"`
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
Tag string `json:"tag,omitempty"`
Status Status `json:"status,omitempty"`
Identity string `json:"identity,omitempty"`
Group string `json:"group,omitempty"`
Channel string `json:"channel,omitempty"`
ConnectionType string `json:"connection_type,omitempty"`
RoleName string `json:"role_name,omitempty"`
RoleID string `json:"role_id,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
IDs []string `json:"-"`
}
// Metadata represents arbitrary JSON.
+45 -7
View File
@@ -205,7 +205,6 @@ func (vcpe viewClientPermsEvent) Encode() (map[string]interface{}, error) {
}
type listClientEvent struct {
reqUserID string
clients.Page
authn.Session
}
@@ -213,7 +212,6 @@ type listClientEvent struct {
func (lce listClientEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": clientList,
"reqUserID": lce.reqUserID,
"total": lce.Total,
"offset": lce.Offset,
"limit": lce.Limit,
@@ -238,8 +236,51 @@ func (lce listClientEvent) Encode() (map[string]interface{}, error) {
if lce.Tag != "" {
val["tag"] = lce.Tag
}
if lce.Permission != "" {
val["permission"] = lce.Permission
if lce.Status.String() != "" {
val["status"] = lce.Status.String()
}
if len(lce.IDs) > 0 {
val["ids"] = lce.IDs
}
if lce.Identity != "" {
val["identity"] = lce.Identity
}
return val, nil
}
type listUserClientEvent struct {
userID string
clients.Page
authn.Session
}
func (lce listUserClientEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": clientList,
"req_user_id": lce.userID,
"total": lce.Total,
"offset": lce.Offset,
"limit": lce.Limit,
"domain": lce.DomainID,
"user_id": lce.UserID,
"token_type": lce.Type.String(),
"super_admin": lce.SuperAdmin,
}
if lce.Name != "" {
val["name"] = lce.Name
}
if lce.Order != "" {
val["order"] = lce.Order
}
if lce.Dir != "" {
val["dir"] = lce.Dir
}
if lce.Metadata != nil {
val["metadata"] = lce.Metadata
}
if lce.Tag != "" {
val["tag"] = lce.Tag
}
if lce.Status.String() != "" {
val["status"] = lce.Status.String()
@@ -288,9 +329,6 @@ func (lcge listClientByGroupEvent) Encode() (map[string]interface{}, error) {
if lcge.Tag != "" {
val["tag"] = lcge.Tag
}
if lcge.Permission != "" {
val["permission"] = lcge.Permission
}
if lcge.Status.String() != "" {
val["status"] = lcge.Status.String()
}
+21 -5
View File
@@ -118,15 +118,31 @@ func (es *eventStore) View(ctx context.Context, session authn.Session, id string
return cli, nil
}
func (es *eventStore) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
cp, err := es.svc.ListClients(ctx, session, reqUserID, pm)
func (es *eventStore) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
cp, err := es.svc.ListClients(ctx, session, pm)
if err != nil {
return cp, err
}
event := listClientEvent{
reqUserID: reqUserID,
Page: pm,
Session: session,
pm,
session,
}
if err := es.Publish(ctx, event); err != nil {
return cp, err
}
return cp, nil
}
func (es *eventStore) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
cp, err := es.svc.ListUserClients(ctx, session, userID, pm)
if err != nil {
return cp, err
}
event := listUserClientEvent{
userID,
pm,
session,
}
if err := es.Publish(ctx, event); err != nil {
return cp, err
+25 -3
View File
@@ -129,7 +129,29 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi
return am.svc.View(ctx, session, id)
}
func (am *authorizationMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
func (am *authorizationMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.ListOp,
EntityIDs: auth.AnyIDs{}.Values(),
}); err != nil {
return clients.ClientsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
}
if err := am.checkSuperAdmin(ctx, session.UserID); err == nil {
session.SuperAdmin = true
}
return am.svc.ListClients(ctx, session, pm)
}
func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
@@ -145,10 +167,10 @@ func (am *authorizationMiddleware) ListClients(ctx context.Context, session auth
}
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
session.SuperAdmin = true
return clients.ClientsPage{}, err
}
return am.svc.ListClients(ctx, session, reqUserID, pm)
return am.svc.ListUserClients(ctx, session, userID, pm)
}
func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
+23 -3
View File
@@ -65,11 +65,10 @@ func (lm *loggingMiddleware) View(ctx context.Context, session authn.Session, id
return lm.svc.View(ctx, session, id)
}
func (lm *loggingMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (cp clients.ClientsPage, err error) {
func (lm *loggingMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (cp clients.ClientsPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", reqUserID),
slog.Group("page",
slog.Uint64("limit", pm.Limit),
slog.Uint64("offset", pm.Offset),
@@ -83,7 +82,28 @@ func (lm *loggingMiddleware) ListClients(ctx context.Context, session authn.Sess
}
lm.logger.Info("List clients completed successfully", args...)
}(time.Now())
return lm.svc.ListClients(ctx, session, reqUserID, pm)
return lm.svc.ListClients(ctx, session, pm)
}
func (lm *loggingMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (cp clients.ClientsPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", userID),
slog.Group("page",
slog.Uint64("limit", pm.Limit),
slog.Uint64("offset", pm.Offset),
slog.Uint64("total", cp.Total),
),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("List clients failed", args...)
return
}
lm.logger.Info("List clients completed successfully", args...)
}(time.Now())
return lm.svc.ListUserClients(ctx, session, userID, pm)
}
func (lm *loggingMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (c clients.Client, err error) {
+10 -2
View File
@@ -49,12 +49,20 @@ func (ms *metricsMiddleware) View(ctx context.Context, session authn.Session, id
return ms.svc.View(ctx, session, id)
}
func (ms *metricsMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
func (ms *metricsMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_clients").Add(1)
ms.latency.With("method", "list_clients").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ListClients(ctx, session, reqUserID, pm)
return ms.svc.ListClients(ctx, session, pm)
}
func (ms *metricsMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_user_clients").Add(1)
ms.latency.With("method", "list_user_clients").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ListUserClients(ctx, session, userID, pm)
}
func (ms *metricsMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
+28 -28
View File
@@ -312,34 +312,6 @@ func (_m *Repository) RetrieveAll(ctx context.Context, pm clients.Page) (clients
return r0, r1
}
// RetrieveAllByIDs provides a mock function with given fields: ctx, pm
func (_m *Repository) RetrieveAllByIDs(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
ret := _m.Called(ctx, pm)
if len(ret) == 0 {
panic("no return value specified for RetrieveAllByIDs")
}
var r0 clients.ClientsPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, clients.Page) (clients.ClientsPage, error)); ok {
return rf(ctx, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, clients.Page) clients.ClientsPage); ok {
r0 = rf(ctx, pm)
} else {
r0 = ret.Get(0).(clients.ClientsPage)
}
if rf, ok := ret.Get(1).(func(context.Context, clients.Page) error); ok {
r1 = rf(ctx, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveAllRoles provides a mock function with given fields: ctx, entityID, limit, offset
func (_m *Repository) RetrieveAllRoles(ctx context.Context, entityID string, limit uint64, offset uint64) (roles.RolePage, error) {
ret := _m.Called(ctx, entityID, limit, offset)
@@ -577,6 +549,34 @@ func (_m *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Ro
return r0, r1
}
// RetrieveUserClients provides a mock function with given fields: ctx, domainID, userID, pm
func (_m *Repository) RetrieveUserClients(ctx context.Context, domainID string, userID string, pm clients.Page) (clients.ClientsPage, error) {
ret := _m.Called(ctx, domainID, userID, pm)
if len(ret) == 0 {
panic("no return value specified for RetrieveUserClients")
}
var r0 clients.ClientsPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, clients.Page) (clients.ClientsPage, error)); ok {
return rf(ctx, domainID, userID, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, clients.Page) clients.ClientsPage); ok {
r0 = rf(ctx, domainID, userID, pm)
} else {
r0 = ret.Get(0).(clients.ClientsPage)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, clients.Page) error); ok {
r1 = rf(ctx, domainID, userID, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RoleAddActions provides a mock function with given fields: ctx, role, actions
func (_m *Repository) RoleAddActions(ctx context.Context, role roles.Role, actions []string) ([]string, error) {
ret := _m.Called(ctx, role, actions)
+34 -6
View File
@@ -198,27 +198,55 @@ func (_m *Service) ListAvailableActions(ctx context.Context, session authn.Sessi
return r0, r1
}
// ListClients provides a mock function with given fields: ctx, session, reqUserID, pm
func (_m *Service) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
ret := _m.Called(ctx, session, reqUserID, pm)
// ListClients provides a mock function with given fields: ctx, session, pm
func (_m *Service) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
ret := _m.Called(ctx, session, pm)
if len(ret) == 0 {
panic("no return value specified for ListClients")
}
var r0 clients.ClientsPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, clients.Page) (clients.ClientsPage, error)); ok {
return rf(ctx, session, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, clients.Page) clients.ClientsPage); ok {
r0 = rf(ctx, session, pm)
} else {
r0 = ret.Get(0).(clients.ClientsPage)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, clients.Page) error); ok {
r1 = rf(ctx, session, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListUserClients provides a mock function with given fields: ctx, session, userID, pm
func (_m *Service) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
ret := _m.Called(ctx, session, userID, pm)
if len(ret) == 0 {
panic("no return value specified for ListUserClients")
}
var r0 clients.ClientsPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, clients.Page) (clients.ClientsPage, error)); ok {
return rf(ctx, session, reqUserID, pm)
return rf(ctx, session, userID, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, clients.Page) clients.ClientsPage); ok {
r0 = rf(ctx, session, reqUserID, pm)
r0 = rf(ctx, session, userID, pm)
} else {
r0 = ret.Get(0).(clients.ClientsPage)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, clients.Page) error); ok {
r1 = rf(ctx, session, reqUserID, pm)
r1 = rf(ctx, session, userID, pm)
} else {
r1 = ret.Error(1)
}
+441 -105
View File
@@ -20,6 +20,7 @@ import (
"github.com/absmach/supermq/pkg/postgres"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
"github.com/jackc/pgtype"
"github.com/lib/pq"
)
const (
@@ -247,6 +248,350 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien
return page, nil
}
func (repo *clientRepo) RetrieveUserClients(ctx context.Context, domainID, userID string, pm clients.Page) (clients.ClientsPage, error) {
return repo.retrieveClients(ctx, domainID, userID, pm)
}
func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID string, pm clients.Page) (clients.ClientsPage, error) {
pageQuery, err := PageQuery(pm)
if err != nil {
return clients.ClientsPage{}, err
}
bq := repo.userClientBaseQuery(domainID, userID)
q := fmt.Sprintf(`
%s
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.identity,
c.secret,
c.tags,
c.metadata,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
c.parent_group_path,
c.role_id,
c.role_name,
c.actions,
c.access_type,
c.access_provider_id,
c.access_provider_role_id,
c.access_provider_role_name,
c.access_provider_role_actions
FROM
final_clients c
%s
`, bq, pageQuery)
q = applyOrdering(q, pm)
dbPage, err := ToDBClientsPage(pm)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer rows.Close()
var items []clients.Client
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
c, err := ToClient(dbc)
if err != nil {
return clients.ClientsPage{}, err
}
items = append(items, c)
}
connJoinQuery := ""
if pm.Channel != "" {
connJoinQuery = "JOIN connection conn ON conn.client_id = c.id"
}
cq := fmt.Sprintf(`%s
SELECT COUNT(*) AS total_count
FROM (
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.identity,
c.secret,
c.tags,
c.metadata,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
c.parent_group_path,
c.role_id,
c.role_name,
c.actions,
c.access_type,
c.access_provider_id,
c.access_provider_role_id,
c.access_provider_role_name,
c.access_provider_role_actions
FROM
final_clients c
%s
%s
) AS subquery;
`, bq, connJoinQuery, pageQuery)
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
page := clients.ClientsPage{
Clients: items,
Page: clients.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
}
return page, nil
}
func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
return fmt.Sprintf(`
WITH direct_clients AS (
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.identity,
c.secret,
c.tags,
c.metadata,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
text2ltree('') as parent_group_path,
cr.id AS role_id,
cr."name" AS role_name,
array_agg(cra."action") AS actions,
'direct' as access_type,
'' AS access_provider_id,
'' AS access_provider_role_id,
'' AS access_provider_role_name,
array[]::::text[] AS access_provider_role_actions
FROM
clients_role_members crm
JOIN
clients_role_actions cra ON cra.role_id = crm.role_id
JOIN
clients_roles cr ON cr.id = crm.role_id
JOIN
clients c ON c.id = cr.entity_id
WHERE
crm.member_id = '%s'
AND c.domain_id = '%s'
GROUP BY
cr.entity_id, crm.member_id, cr.id, cr."name", c.id
),
direct_groups AS (
SELECT
g.*,
gr.entity_id AS entity_id,
grm.member_id AS member_id,
gr.id AS role_id,
gr."name" AS role_name,
array_agg(gra."action") AS actions
FROM
groups_role_members grm
JOIN
groups_role_actions gra ON gra.role_id = grm.role_id
JOIN
groups_roles gr ON gr.id = grm.role_id
JOIN
"groups" g ON g.id = gr.entity_id
WHERE
grm.member_id = '%s'
AND g.domain_id = '%s'
GROUP BY
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
),
direct_groups_with_subgroup AS (
SELECT
*
FROM direct_groups
WHERE EXISTS (
SELECT 1
FROM unnest(direct_groups.actions) AS action
WHERE action LIKE 'subgroup_%%'
)
),
indirect_child_groups AS (
SELECT
DISTINCT indirect_child_groups.id as child_id,
indirect_child_groups.*,
dgws.id as access_provider_id,
dgws.role_id as access_provider_role_id,
dgws.role_name as access_provider_role_name,
dgws.actions as access_provider_role_actions
FROM
direct_groups_with_subgroup dgws
JOIN
groups indirect_child_groups ON indirect_child_groups.path <@ dgws.path
WHERE
indirect_child_groups.domain_id = '%s'
AND NOT EXISTS (
SELECT 1
FROM direct_groups_with_subgroup dgws
WHERE dgws.id = indirect_child_groups.id
)
),
final_groups AS (
SELECT
id,
parent_id,
domain_id,
"name",
description,
metadata,
created_at,
updated_at,
updated_by,
status,
"path",
role_id,
role_name,
actions,
'direct_group' AS access_type,
'' AS access_provider_id,
'' AS access_provider_role_id,
'' AS access_provider_role_name,
array[]::::text[] AS access_provider_role_actions
FROM
direct_groups
UNION
SELECT
id,
parent_id,
domain_id,
"name",
description,
metadata,
created_at,
updated_at,
updated_by,
status,
"path",
'' AS role_id,
'' AS role_name,
array[]::::text[] AS actions,
'indirect_group' AS access_type,
access_provider_id,
access_provider_role_id,
access_provider_role_name,
access_provider_role_actions
FROM
indirect_child_groups
),
group_direct_clients AS (
SELECT
c.id,
c.name,
c.domain_id,
c.parent_group_id,
c.identity,
c.secret,
c.tags,
c.metadata,
c.created_at,
c.updated_at,
c.updated_by,
c.status,
g.path AS parent_group_path,
g.role_id,
g.role_name,
g.actions,
g.access_type,
g.access_provider_id,
g.access_provider_role_id,
g.access_provider_role_name,
g.access_provider_role_actions
FROM
final_groups g
JOIN
clients c ON c.parent_group_id = g.id
WHERE
c.id NOT IN (SELECT id FROM direct_clients)
UNION
SELECT
dc.id,
dc.name,
dc.domain_id,
dc.parent_group_id,
dc.identity,
dc.secret,
dc.tags,
dc.metadata,
dc.created_at,
dc.updated_at,
dc.updated_by,
dc.status,
dc.parent_group_path,
dc.role_id,
dc.role_name,
dc.actions,
dc.access_type,
dc.access_provider_id,
dc.access_provider_role_id,
dc.access_provider_role_name,
dc.access_provider_role_actions
FROM
direct_clients AS dc
),
final_clients AS (
SELECT
gdc.id,
gdc.name,
gdc.domain_id,
gdc.parent_group_id,
gdc.identity,
gdc.secret,
gdc.tags,
gdc.metadata,
gdc.created_at,
gdc.updated_at,
gdc.updated_by,
gdc.status,
gdc.parent_group_path,
gdc.role_id,
gdc.role_name,
gdc.actions,
gdc.access_type,
gdc.access_provider_id,
gdc.access_provider_role_id,
gdc.access_provider_role_name,
gdc.access_provider_role_actions
FROM
group_direct_clients AS gdc
)
`, userID, domainID, userID, domainID, domainID)
}
func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
query, err := PageQuery(pm)
if err != nil {
@@ -302,64 +647,6 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli
return page, nil
}
func (repo *clientRepo) RetrieveAllByIDs(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
if (len(pm.IDs) == 0) && (pm.Domain == "") {
return clients.ClientsPage{
Page: clients.Page{Total: pm.Total, Offset: pm.Offset, Limit: pm.Limit},
}, nil
}
query, err := PageQuery(pm)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
query = applyOrdering(query, pm)
q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, c.status,
c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query)
dbPage, err := ToDBClientsPage(pm)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
var items []clients.Client
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
c, err := ToClient(dbc)
if err != nil {
return clients.ClientsPage{}, err
}
items = append(items, c)
}
cq := fmt.Sprintf(`SELECT COUNT(*) FROM clients c %s;`, query)
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
page := clients.ClientsPage{
Clients: items,
Page: clients.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
},
}
return page, nil
}
func (repo *clientRepo) update(ctx context.Context, client clients.Client, query string) (clients.Client, error) {
dbc, err := ToDBClient(client)
if err != nil {
@@ -402,18 +689,27 @@ func (repo *clientRepo) Delete(ctx context.Context, clientIDs ...string) error {
}
type DBClient struct {
ID string `db:"id"`
Name string `db:"name,omitempty"`
Tags pgtype.TextArray `db:"tags,omitempty"`
Identity string `db:"identity"`
Domain string `db:"domain_id"`
ParentGroup sql.NullString `db:"parent_group_id,omitempty"`
Secret string `db:"secret"`
Metadata []byte `db:"metadata,omitempty"`
CreatedAt time.Time `db:"created_at,omitempty"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
UpdatedBy *string `db:"updated_by,omitempty"`
Status clients.Status `db:"status,omitempty"`
ID string `db:"id"`
Name string `db:"name,omitempty"`
Tags pgtype.TextArray `db:"tags,omitempty"`
Identity string `db:"identity"`
Domain string `db:"domain_id"`
ParentGroup sql.NullString `db:"parent_group_id,omitempty"`
Secret string `db:"secret"`
Metadata []byte `db:"metadata,omitempty"`
CreatedAt time.Time `db:"created_at,omitempty"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
UpdatedBy *string `db:"updated_by,omitempty"`
Status clients.Status `db:"status,omitempty"`
ParentGroupPath string `db:"parent_group_path,omitempty"`
RoleID string `db:"role_id,omitempty"`
RoleName string `db:"role_name,omitempty"`
Actions pq.StringArray `db:"actions,omitempty"`
AccessType string `db:"access_type,omitempty"`
AccessProviderId string `db:"access_provider_id,omitempty"`
AccessProviderRoleId string `db:"access_provider_role_id,omitempty"`
AccessProviderRoleName string `db:"access_provider_role_name,omitempty"`
AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions,omitempty"`
}
func ToDBClient(c clients.Client) (DBClient, error) {
@@ -484,11 +780,19 @@ func ToClient(t DBClient) (clients.Client, error) {
Identity: t.Identity,
Secret: t.Secret,
},
Metadata: metadata,
CreatedAt: t.CreatedAt,
UpdatedAt: updatedAt,
UpdatedBy: updatedBy,
Status: t.Status,
Metadata: metadata,
CreatedAt: t.CreatedAt,
UpdatedAt: updatedAt,
UpdatedBy: updatedBy,
Status: t.Status,
RoleID: t.RoleID,
RoleName: t.RoleName,
Actions: t.Actions,
AccessType: t.AccessType,
AccessProviderId: t.AccessProviderId,
AccessProviderRoleId: t.AccessProviderRoleId,
AccessProviderRoleName: t.AccessProviderRoleName,
AccessProviderRoleActions: t.AccessProviderRoleActions,
}
return cli, nil
}
@@ -499,31 +803,42 @@ func ToDBClientsPage(pm clients.Page) (dbClientsPage, error) {
return dbClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
return dbClientsPage{
Name: pm.Name,
Identity: pm.Identity,
Id: pm.Id,
Metadata: data,
Domain: pm.Domain,
Total: pm.Total,
Offset: pm.Offset,
Limit: pm.Limit,
Status: pm.Status,
Tag: pm.Tag,
Offset: pm.Offset,
Limit: pm.Limit,
Name: pm.Name,
Identity: pm.Identity,
Id: pm.Id,
Metadata: data,
Domain: pm.Domain,
Status: pm.Status,
Tag: pm.Tag,
GroupID: pm.Group,
ChannelID: pm.Channel,
RoleName: pm.RoleName,
ConnType: pm.ConnectionType,
RoleID: pm.RoleID,
Actions: pm.Actions,
AccessType: pm.AccessType,
}, nil
}
type dbClientsPage struct {
Total uint64 `db:"total"`
Limit uint64 `db:"limit"`
Offset uint64 `db:"offset"`
Name string `db:"name"`
Id string `db:"id"`
Domain string `db:"domain_id"`
Identity string `db:"identity"`
Metadata []byte `db:"metadata"`
Tag string `db:"tag"`
Status clients.Status `db:"status"`
GroupID string `db:"group_id"`
Limit uint64 `db:"limit"`
Offset uint64 `db:"offset"`
Name string `db:"name"`
Id string `db:"id"`
Domain string `db:"domain_id"`
Identity string `db:"identity"`
Metadata []byte `db:"metadata"`
Tag string `db:"tag"`
Status clients.Status `db:"status"`
GroupID string `db:"group_id"`
ChannelID string `db:"channel_id"`
ConnType string `db:"type"`
RoleName string `db:"role_name"`
RoleID string `db:"role_id"`
Actions pq.StringArray `db:"actions"`
AccessType string `db:"access_type"`
}
func PageQuery(pm clients.Page) (string, error) {
@@ -534,36 +849,57 @@ func PageQuery(pm clients.Page) (string, error) {
var query []string
if pm.Name != "" {
query = append(query, "name ILIKE '%' || :name || '%'")
query = append(query, "c.name ILIKE '%' || :name || '%'")
}
if pm.Identity != "" {
query = append(query, "identity ILIKE '%' || :identity || '%'")
query = append(query, "c.identity ILIKE '%' || :identity || '%'")
}
if pm.Id != "" {
query = append(query, "id ILIKE '%' || :id || '%'")
query = append(query, "c.id ILIKE '%' || :id || '%'")
}
if pm.Tag != "" {
query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')")
}
// If there are search params presents, use search and ignore other options.
// Always combine role with search params, so len(query) > 1.
if len(query) > 1 {
return fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")), nil
}
if mq != "" {
query = append(query, mq)
}
if len(pm.IDs) != 0 {
query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','")))
query = append(query, fmt.Sprintf("c.id IN ('%s')", strings.Join(pm.IDs, "','")))
}
if pm.Status != clients.AllStatus {
query = append(query, "c.status = :status")
}
if pm.Domain != "" {
query = append(query, "c.domain_id = :domain_id")
}
if pm.Group != "" {
query = append(query, "c.parent_group_path @> (SELECT path from groups where id = :group_id) ")
}
if pm.Channel != "" {
query = append(query, "conn.channel_id = :channel_id ")
if pm.ConnectionType != "" {
query = append(query, "conn.type = :conn_type ")
}
}
if pm.AccessType != "" {
query = append(query, "c.access_type = :access_type")
}
if pm.RoleID != "" {
query = append(query, "c.role_id = :role_id")
}
if pm.RoleName != "" {
query = append(query, "c.role_name = :role_name")
}
if len(pm.Actions) != 0 {
query = append(query, "c.actions @> :actions")
}
if len(pm.Metadata) > 0 {
query = append(query, "c.metadata @> :metadata")
}
var emq string
if len(query) > 0 {
emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND "))
+1 -268
View File
@@ -1626,274 +1626,7 @@ func TestSearchClients(t *testing.T) {
}
}
func TestRetrieveAllByIDs(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM clients")
require.Nil(t, err, fmt.Sprintf("clean clients unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
num := 200
var items []clients.Client
for i := 0; i < num; i++ {
name := namegen.Generate()
client := clients.Client{
ID: testsutil.GenerateUUID(t),
Domain: testsutil.GenerateUUID(t),
Name: name,
Credentials: clients.Credentials{
Identity: name + emailSuffix,
Secret: testsutil.GenerateUUID(t),
},
Tags: namegen.GenerateMultiple(5),
Metadata: map[string]interface{}{"name": name},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
Status: clients.EnabledStatus,
}
_, err := repo.Save(context.Background(), client)
require.Nil(t, err, fmt.Sprintf("add new client: expected nil got %s\n", err))
items = append(items, client)
}
page, err := repo.RetrieveAll(context.Background(), clients.Page{Offset: 0, Limit: uint64(num)})
require.Nil(t, err, fmt.Sprintf("retrieve all clients unexpected error: %s", err))
assert.Equal(t, uint64(num), page.Total)
cases := []struct {
desc string
page clients.Page
response clients.ClientsPage
err error
}{
{
desc: "successfully",
page: clients.Page{
Offset: 0,
Limit: 10,
IDs: getIDs(items[0:3]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 3,
Offset: 0,
Limit: 10,
},
Clients: items[0:3],
},
err: nil,
},
{
desc: "with empty ids",
page: clients.Page{
Offset: 0,
Limit: 10,
IDs: []string{},
},
response: clients.ClientsPage{
Page: clients.Page{
Offset: 0,
Limit: 10,
},
Clients: []clients.Client(nil),
},
err: nil,
},
{
desc: "with empty ids but with domain id",
page: clients.Page{
Offset: 0,
Limit: 10,
Domain: items[0].Domain,
IDs: []string{},
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 1,
Offset: 0,
Limit: 10,
},
Clients: []clients.Client{items[0]},
},
err: nil,
},
{
desc: "with offset only",
page: clients.Page{
Offset: 10,
IDs: getIDs(items[0:20]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 20,
Offset: 10,
Limit: 0,
},
Clients: []clients.Client(nil),
},
err: nil,
},
{
desc: "with limit only",
page: clients.Page{
Limit: 10,
IDs: getIDs(items[0:20]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 20,
Offset: 0,
Limit: 10,
},
Clients: items[0:10],
},
err: nil,
},
{
desc: "with offset out of range",
page: clients.Page{
Offset: 1000,
Limit: 50,
IDs: getIDs(items[0:20]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 20,
Offset: 1000,
Limit: 50,
},
Clients: []clients.Client(nil),
},
err: nil,
},
{
desc: "with offset and limit out of range",
page: clients.Page{
Offset: 15,
Limit: 10,
IDs: getIDs(items[0:20]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 20,
Offset: 15,
Limit: 10,
},
Clients: items[15:20],
},
err: nil,
},
{
desc: "with limit out of range",
page: clients.Page{
Offset: 0,
Limit: 1000,
IDs: getIDs(items[0:20]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 20,
Offset: 0,
Limit: 1000,
},
Clients: items[:20],
},
err: nil,
},
{
desc: "with name",
page: clients.Page{
Offset: 0,
Limit: 10,
Name: items[0].Name,
IDs: getIDs(items[0:20]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 1,
Offset: 0,
Limit: 10,
},
Clients: []clients.Client{items[0]},
},
err: nil,
},
{
desc: "with domain id",
page: clients.Page{
Offset: 0,
Limit: 10,
Domain: items[0].Domain,
IDs: getIDs(items[0:20]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 1,
Offset: 0,
Limit: 10,
},
Clients: []clients.Client{items[0]},
},
err: nil,
},
{
desc: "with metadata",
page: clients.Page{
Offset: 0,
Limit: 10,
Metadata: items[0].Metadata,
IDs: getIDs(items[0:20]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 1,
Offset: 0,
Limit: 10,
},
Clients: []clients.Client{items[0]},
},
err: nil,
},
{
desc: "with invalid metadata",
page: clients.Page{
Offset: 0,
Limit: 10,
Metadata: map[string]interface{}{
"key": make(chan int),
},
IDs: getIDs(items[0:20]),
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 0,
Offset: 0,
Limit: 10,
},
Clients: []clients.Client(nil),
},
err: errors.ErrMalformedEntity,
},
}
for _, c := range cases {
switch response, err := repo.RetrieveAllByIDs(context.Background(), c.page); {
case err == nil:
assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", c.desc, c.err, err))
assert.Equal(t, c.response.Total, response.Total)
assert.Equal(t, c.response.Limit, response.Limit)
assert.Equal(t, c.response.Offset, response.Offset)
expected := stripClientDetails(c.response.Clients)
got := stripClientDetails(response.Clients)
assert.ElementsMatch(t, expected, got)
default:
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
}
}
}
func TestRetrievByIDs(t *testing.T) {
func TestRetrieveByIDs(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM clients")
require.Nil(t, err, fmt.Sprintf("clean clients unexpected error: %s", err))
+8
View File
@@ -4,6 +4,7 @@
package postgres
import (
gpostgres "github.com/absmach/supermq/groups/postgres"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
@@ -60,5 +61,12 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
clientsMigration.Migrations = append(clientsMigration.Migrations, clientsRolesMigration.Migrations...)
groupsMigration, err := gpostgres.Migration()
if err != nil {
return &migrate.MemoryMigrationSource{}, err
}
clientsMigration.Migrations = append(clientsMigration.Migrations, groupsMigration.Migrations...)
return clientsMigration, nil
}
+1 -1
View File
@@ -144,7 +144,7 @@ func NewRolesOperationPermissionMap() map[svcutil.Operation]svcutil.Permission {
const (
// External Permission for domains.
domainCreateClientPermission = "client_create_permission"
domainListClientsPermission = "list_clients_permission"
domainListClientsPermission = "client_read_permission"
// External Permission for groups.
groupSetChildClientPermission = "client_create_permission"
groupRemoveChildClientPermission = "client_create_permission"
+14 -100
View File
@@ -12,13 +12,11 @@ import (
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
apiutil "github.com/absmach/supermq/api/http/util"
smqauth "github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
"golang.org/x/sync/errgroup"
)
var (
@@ -131,113 +129,29 @@ func (svc service) View(ctx context.Context, session authn.Session, id string) (
return client, nil
}
func (svc service) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm Page) (ClientsPage, error) {
var ids []string
var err error
switch {
case (reqUserID != "" && reqUserID != session.UserID):
rtids, err := svc.listClientIDs(ctx, smqauth.EncodeDomainUserID(session.DomainID, reqUserID), pm.Permission)
func (svc service) ListClients(ctx context.Context, session authn.Session, pm Page) (ClientsPage, error) {
switch session.SuperAdmin {
case true:
cp, err := svc.repo.RetrieveAll(ctx, pm)
if err != nil {
return ClientsPage{}, errors.Wrap(svcerr.ErrNotFound, err)
}
ids, err = svc.filterAllowedClientIDs(ctx, session.DomainUserID, pm.Permission, rtids)
if err != nil {
return ClientsPage{}, errors.Wrap(svcerr.ErrNotFound, err)
return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return cp, nil
default:
switch session.SuperAdmin {
case true:
pm.Domain = session.DomainID
default:
ids, err = svc.listClientIDs(ctx, session.DomainUserID, pm.Permission)
if err != nil {
return ClientsPage{}, errors.Wrap(svcerr.ErrNotFound, err)
}
cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, session.UserID, pm)
if err != nil {
return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return cp, nil
}
}
if len(ids) == 0 && pm.Domain == "" {
return ClientsPage{}, nil
}
pm.IDs = ids
tp, err := svc.repo.SearchClients(ctx, pm)
func (svc service) ListUserClients(ctx context.Context, session authn.Session, userID string, pm Page) (ClientsPage, error) {
cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, userID, pm)
if err != nil {
return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if pm.ListPerms && len(tp.Clients) > 0 {
g, ctx := errgroup.WithContext(ctx)
for i := range tp.Clients {
// Copying loop variable "i" to avoid "loop variable captured by func literal"
iter := i
g.Go(func() error {
return svc.retrievePermissions(ctx, session.DomainUserID, &tp.Clients[iter])
})
}
if err := g.Wait(); err != nil {
return ClientsPage{}, err
}
}
return tp, nil
}
// Experimental functions used for async calling of svc.listUserClientPermission. This might be helpful during listing of large number of entities.
func (svc service) retrievePermissions(ctx context.Context, userID string, client *Client) error {
permissions, err := svc.listUserClientPermission(ctx, userID, client.ID)
if err != nil {
return err
}
client.Permissions = permissions
return nil
}
func (svc service) listUserClientPermission(ctx context.Context, userID, clientID string) ([]string, error) {
permissions, err := svc.policy.ListPermissions(ctx, policies.Policy{
SubjectType: policies.UserType,
Subject: userID,
Object: clientID,
ObjectType: policies.ClientType,
}, []string{})
if err != nil {
return []string{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
return permissions, nil
}
func (svc service) listClientIDs(ctx context.Context, userID, permission string) ([]string, error) {
tids, err := svc.policy.ListAllObjects(ctx, policies.Policy{
SubjectType: policies.UserType,
Subject: userID,
Permission: permission,
ObjectType: policies.ClientType,
})
if err != nil {
return nil, errors.Wrap(svcerr.ErrNotFound, err)
}
return tids.Policies, nil
}
func (svc service) filterAllowedClientIDs(ctx context.Context, userID, permission string, clientIDs []string) ([]string, error) {
var ids []string
tids, err := svc.policy.ListAllObjects(ctx, policies.Policy{
SubjectType: policies.UserType,
Subject: userID,
Permission: permission,
ObjectType: policies.ClientType,
})
if err != nil {
return nil, errors.Wrap(svcerr.ErrNotFound, err)
}
for _, clientID := range clientIDs {
for _, tid := range tids.Policies {
if clientID == tid {
ids = append(ids, clientID)
}
}
}
return ids, nil
return cp, nil
}
func (svc service) Update(ctx context.Context, session authn.Session, cli Client) (Client, error) {
+24 -99
View File
@@ -351,7 +351,6 @@ func TestListClients(t *testing.T) {
adminID := testsutil.GenerateUUID(t)
domainID := testsutil.GenerateUUID(t)
nonAdminID := testsutil.GenerateUUID(t)
client.Permissions = []string{"read", "edit"}
cases := []struct {
desc string
@@ -375,9 +374,8 @@ func TestListClients(t *testing.T) {
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: clients.Page{
Offset: 0,
Limit: 100,
ListPerms: true,
Offset: 0,
Limit: 100,
},
listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}},
retrieveAllResponse: clients.ClientsPage{
@@ -388,7 +386,6 @@ func TestListClients(t *testing.T) {
},
Clients: []clients.Client{client, client},
},
listPermissionsResponse: client.Permissions,
response: clients.ClientsPage{
Page: clients.Page{
Total: 2,
@@ -405,9 +402,8 @@ func TestListClients(t *testing.T) {
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: clients.Page{
Offset: 0,
Limit: 100,
ListPerms: true,
Offset: 0,
Limit: 100,
},
listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}},
retrieveAllResponse: clients.ClientsPage{},
@@ -415,39 +411,14 @@ func TestListClients(t *testing.T) {
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list all clients as non admin with failed to list permissions",
userKind: "non-admin",
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: clients.Page{
Offset: 0,
Limit: 100,
ListPerms: true,
},
listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}},
retrieveAllResponse: clients.ClientsPage{
Page: clients.Page{
Total: 2,
Offset: 0,
Limit: 100,
},
Clients: []clients.Client{client, client},
},
listPermissionsResponse: []string{},
response: clients.ClientsPage{},
listPermissionsErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list all clients as non admin with failed super admin",
userKind: "non-admin",
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: clients.Page{
Offset: 0,
Limit: 100,
ListPerms: true,
Offset: 0,
Limit: 100,
},
response: clients.ClientsPage{},
listObjectsResponse: policysvc.PolicyPage{},
@@ -458,10 +429,10 @@ func TestListClients(t *testing.T) {
userKind: "non-admin",
id: nonAdminID,
page: clients.Page{
Offset: 0,
Limit: 100,
ListPerms: true,
Offset: 0,
Limit: 100,
},
retrieveAllErr: repoerr.ErrNotFound,
response: clients.ClientsPage{},
listObjectsResponse: policysvc.PolicyPage{},
listObjectsErr: svcerr.ErrNotFound,
@@ -470,15 +441,13 @@ func TestListClients(t *testing.T) {
}
for _, tc := range cases {
listAllObjectsCall := pService.On("ListAllObjects", mock.Anything, mock.Anything).Return(tc.listObjectsResponse, tc.listObjectsErr)
retrieveAllCall := repo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
listPermissionsCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr)
page, err := svc.ListClients(context.Background(), tc.session, tc.id, tc.page)
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
retrieveUserClientsCall := repo.On("RetrieveUserClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
page, err := svc.ListClients(context.Background(), tc.session, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
listAllObjectsCall.Unset()
retrieveAllCall.Unset()
listPermissionsCall.Unset()
retrieveUserClientsCall.Unset()
}
cases2 := []struct {
@@ -503,10 +472,9 @@ func TestListClients(t *testing.T) {
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: clients.Page{
Offset: 0,
Limit: 100,
ListPerms: true,
Domain: domainID,
Offset: 0,
Limit: 100,
Domain: domainID,
},
listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}},
retrieveAllResponse: clients.ClientsPage{
@@ -517,7 +485,6 @@ func TestListClients(t *testing.T) {
},
Clients: []clients.Client{client, client},
},
listPermissionsResponse: client.Permissions,
response: clients.ClientsPage{
Page: clients.Page{
Total: 2,
@@ -534,50 +501,24 @@ func TestListClients(t *testing.T) {
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: clients.Page{
Offset: 0,
Limit: 100,
ListPerms: true,
Domain: domainID,
Offset: 0,
Limit: 100,
Domain: domainID,
},
listObjectsResponse: policysvc.PolicyPage{},
retrieveAllResponse: clients.ClientsPage{},
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list all clients as admin with failed to list permissions",
userKind: "admin",
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: clients.Page{
Offset: 0,
Limit: 100,
ListPerms: true,
Domain: domainID,
},
listObjectsResponse: policysvc.PolicyPage{},
retrieveAllResponse: clients.ClientsPage{
Page: clients.Page{
Total: 2,
Offset: 0,
Limit: 100,
},
Clients: []clients.Client{client, client},
},
listPermissionsResponse: []string{},
listPermissionsErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "list all clients as admin with failed to list clients",
userKind: "admin",
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: clients.Page{
Offset: 0,
Limit: 100,
ListPerms: true,
Domain: domainID,
Offset: 0,
Limit: 100,
Domain: domainID,
},
retrieveAllResponse: clients.ClientsPage{},
retrieveAllErr: repoerr.ErrNotFound,
@@ -586,27 +527,11 @@ func TestListClients(t *testing.T) {
}
for _, tc := range cases2 {
listAllObjectsCall := pService.On("ListAllObjects", context.Background(), policysvc.Policy{
SubjectType: policysvc.UserType,
Subject: tc.session.DomainID + "_" + adminID,
Permission: "",
ObjectType: policysvc.ClientType,
}).Return(tc.listObjectsResponse, tc.listObjectsErr)
listAllObjectsCall2 := pService.On("ListAllObjects", context.Background(), policysvc.Policy{
SubjectType: policysvc.UserType,
Subject: tc.session.UserID,
Permission: "",
ObjectType: policysvc.ClientType,
}).Return(tc.listObjectsResponse, tc.listObjectsErr)
retrieveAllCall := repo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
listPermissionsCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr)
page, err := svc.ListClients(context.Background(), tc.session, tc.id, tc.page)
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
page, err := svc.ListClients(context.Background(), tc.session, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
listAllObjectsCall.Unset()
listAllObjectsCall2.Unset()
retrieveAllCall.Unset()
listPermissionsCall.Unset()
}
}
+1 -1
View File
@@ -197,7 +197,7 @@ func TestStatusUnmarshalJSON(t *testing.T) {
}
}
func TestUserMarshalJSON(t *testing.T) {
func TestClientMarshalJSON(t *testing.T) {
cases := []struct {
desc string
expected []byte
+8 -2
View File
@@ -47,10 +47,16 @@ func (tm *tracingMiddleware) View(ctx context.Context, session authn.Session, id
}
// ListClients traces the "ListClients" operation of the wrapped clients.Service.
func (tm *tracingMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
func (tm *tracingMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
ctx, span := tm.tracer.Start(ctx, "svc_list_clients")
defer span.End()
return tm.svc.ListClients(ctx, session, reqUserID, pm)
return tm.svc.ListClients(ctx, session, pm)
}
func (tm *tracingMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
ctx, span := tm.tracer.Start(ctx, "svc_list_clients")
defer span.End()
return tm.svc.ListUserClients(ctx, session, userID, pm)
}
// Update traces the "Update" operation of the wrapped clients.Service.
+12
View File
@@ -25,11 +25,13 @@ import (
"github.com/absmach/supermq/channels/postgres"
pChannels "github.com/absmach/supermq/channels/private"
"github.com/absmach/supermq/channels/tracing"
gpostgres "github.com/absmach/supermq/groups/postgres"
smqlog "github.com/absmach/supermq/logger"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
gconsumer "github.com/absmach/supermq/pkg/groups/events/consumer"
"github.com/absmach/supermq/pkg/grpcclient"
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
"github.com/absmach/supermq/pkg/policies"
@@ -74,6 +76,7 @@ type config struct {
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
ESConsumerName string `env:"SMQ_CLIENTS_EVENT_CONSUMER" envDefault:"channels"`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
@@ -224,6 +227,15 @@ func main() {
return
}
gdatabase := pg.NewDatabase(db, dbConfig, tracer)
grepo := gpostgres.New(gdatabase)
if err := gconsumer.GroupsEventsSubscribe(ctx, grepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil {
logger.Error(fmt.Sprintf("failed to create groups event store : %s", err))
exitCode = 1
return
}
grpcServerConfig := server.Config{Port: defSvcGRPCPort}
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s gRPC server configuration : %s", svcName, err))
+12
View File
@@ -27,12 +27,14 @@ import (
"github.com/absmach/supermq/clients/postgres"
pClients "github.com/absmach/supermq/clients/private"
"github.com/absmach/supermq/clients/tracing"
gpostgres "github.com/absmach/supermq/groups/postgres"
redisclient "github.com/absmach/supermq/internal/clients/redis"
smqlog "github.com/absmach/supermq/logger"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
gconsumer "github.com/absmach/supermq/pkg/groups/events/consumer"
"github.com/absmach/supermq/pkg/grpcclient"
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
"github.com/absmach/supermq/pkg/policies"
@@ -82,6 +84,7 @@ type config struct {
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
ESConsumerName string `env:"SMQ_CLIENTS_EVENT_CONSUMER" envDefault:"clients"`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
@@ -241,6 +244,15 @@ func main() {
return
}
gdatabase := pg.NewDatabase(db, dbConfig, tracer)
grepo := gpostgres.New(gdatabase)
if err := gconsumer.GroupsEventsSubscribe(ctx, grepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil {
logger.Error(fmt.Sprintf("failed to create groups event store : %s", err))
exitCode = 1
return
}
httpServerConfig := server.Config{Port: defSvcHTTPPort}
if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err))
+7 -7
View File
@@ -357,13 +357,13 @@ type addChildrenGroupsEvent struct {
func (acge addChildrenGroupsEvent) Encode() (map[string]interface{}, error) {
return map[string]interface{}{
"operation": groupAddChildrenGroups,
"id": acge.id,
"childre_ids": acge.childrenIDs,
"domain": acge.DomainID,
"user_id": acge.UserID,
"token_type": acge.Type.String(),
"super_admin": acge.SuperAdmin,
"operation": groupAddChildrenGroups,
"id": acge.id,
"children_ids": acge.childrenIDs,
"domain": acge.DomainID,
"user_id": acge.UserID,
"token_type": acge.Type.String(),
"super_admin": acge.SuperAdmin,
}, nil
}
+2 -1
View File
@@ -142,9 +142,10 @@ type Service interface {
// ViewGroup retrieves data about the group identified by ID.
ViewGroup(ctx context.Context, session authn.Session, id string) (Group, error)
// ListGroups retrieves
// ListGroups retrieves groups for given filters.
ListGroups(ctx context.Context, session authn.Session, pm PageMeta) (Page, error)
// ListGroups retrieves user accessible groups for given filters.
ListUserGroups(ctx context.Context, session authn.Session, userID string, pm PageMeta) (Page, error)
// EnableGroup logically enables the group identified with the provided ID.
+1
View File
@@ -73,6 +73,7 @@ func AuthorizationMiddleware(entityType string, svc groups.Service, repo groups.
}
return &authorizationMiddleware{
svc: svc,
repo: repo,
authz: authz,
opp: opp,
extOpp: extOpp,
+2 -1
View File
@@ -34,7 +34,8 @@ var (
// ErrFailedToRetrieveAllGroups failed to retrieve groups.
ErrFailedToRetrieveAllGroups = errors.New("failed to retrieve all groups")
ErrRoleMigration = errors.New("role migration initialization failed")
// ErrRoleMigration failed to apply role migrations.
ErrRoleMigration = errors.New("failed to apply role migration")
// ErrMissingNames indicates missing first and last names.
ErrMissingNames = errors.New("missing first or last name")
+1
View File
@@ -43,6 +43,7 @@ type SubscriberConfig struct {
Consumer string
Stream string
Handler EventHandler
Ordered bool
}
// Subscriber specifies event subscription API.
+4 -2
View File
@@ -93,6 +93,7 @@ func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberCon
logger: es.logger,
},
DeliveryPolicy: messaging.DeliverNewPolicy,
Ordered: cfg.Ordered,
}
return es.pubsub.Subscribe(ctx, subCfg)
@@ -126,8 +127,9 @@ func (eh *eventHandler) Handle(msg *messaging.Message) error {
return err
}
if err := eh.handler.Handle(eh.ctx, event); err != nil {
eh.logger.Warn(fmt.Sprintf("failed to handle nats event: %s", err))
err := eh.handler.Handle(eh.ctx, event)
if err != nil {
return fmt.Errorf("failed to handle nats event: %s", err)
}
return nil
+255
View File
@@ -0,0 +1,255 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package consumer
import (
"time"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/roles"
rconsumer "github.com/absmach/supermq/pkg/roles/rolemanager/events/consumer"
)
var (
errDecodeCreateGroupEvent = errors.New("failed to decode group create event")
errDecodeUpdateGroupEvent = errors.New("failed to decode group update event")
errDecodeChangeStatusGroupEvent = errors.New("failed to decode group change status event")
errDecodeRemoveGroupEvent = errors.New("failed to decode group remove event")
errDecodeAddParentGroupEvent = errors.New("failed to decode group add parent event")
errDecodeRemoveParentGroupEvent = errors.New("failed to decode group remove parent event")
errDecodeAddChildrenGroupsEvent = errors.New("failed to decode group add children groups event")
errDecodeRemoveChildrenGroupsEvent = errors.New("failed to decode group remove children groups event")
errID = errors.New("missing or invalid 'id'")
errName = errors.New("missing or invalid 'name'")
errDomain = errors.New("missing or invalid 'domain'")
errParent = errors.New("missing or invalid 'parent'")
errChildrenIDs = errors.New("missing or invalid 'children_ids'")
errStatus = errors.New("missing or invalid 'status'")
errConvertStatus = errors.New("failed to convert status")
errCreatedAt = errors.New("failed to parse 'created_at' time")
errUpdatedAt = errors.New("failed to parse 'updated_at' time")
)
const (
layout = "2006-01-02T15:04:05.999999Z"
)
func ToGroups(data map[string]interface{}) (groups.Group, error) {
var g groups.Group
id, ok := data["id"].(string)
if !ok {
return groups.Group{}, errID
}
g.ID = id
name, ok := data["name"].(string)
if !ok {
return groups.Group{}, errName
}
g.Name = name
dom, ok := data["domain"].(string)
if !ok {
return groups.Group{}, errDomain
}
g.Domain = dom
stat, ok := data["status"].(string)
if !ok {
return groups.Group{}, errStatus
}
st, err := groups.ToStatus(stat)
if err != nil {
return groups.Group{}, errors.Wrap(errConvertStatus, err)
}
g.Status = st
cat, ok := data["created_at"].(string)
if !ok {
return groups.Group{}, errCreatedAt
}
ct, err := time.Parse(layout, cat)
if err != nil {
return groups.Group{}, errors.Wrap(errCreatedAt, err)
}
g.CreatedAt = ct
// Following fields of groups are allowed to be empty.
desc, ok := data["description"].(string)
if ok {
g.Description = desc
}
parent, ok := data["parent"].(string)
if ok {
g.Parent = parent
}
meta, ok := data["metadata"].(map[string]interface{})
if ok {
g.Metadata = meta
}
uby, ok := data["updated_by"].(string)
if ok {
g.UpdatedBy = uby
}
uat, ok := data["updated_at"].(string)
if ok {
ut, err := time.Parse(layout, uat)
if err != nil {
return groups.Group{}, errors.Wrap(errUpdatedAt, err)
}
g.UpdatedAt = ut
}
return g, nil
}
func decodeCreateGroupEvent(data map[string]interface{}) (groups.Group, []roles.RoleProvision, error) {
g, err := ToGroups(data)
if err != nil {
return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDecodeCreateGroupEvent, err)
}
irps, ok := data["roles_provisioned"].([]interface{})
if !ok {
return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDecodeCreateGroupEvent, errors.New("missing or invalid 'roles_provisioned'"))
}
rps, err := rconsumer.ToRoleProvisions(irps)
if err != nil {
return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDecodeCreateGroupEvent, err)
}
return g, rps, nil
}
func decodeUpdateGroupEvent(data map[string]interface{}) (groups.Group, error) {
g, err := ToGroups(data)
if err != nil {
return groups.Group{}, errors.Wrap(errDecodeUpdateGroupEvent, err)
}
return g, nil
}
func ToGroupStatus(data map[string]interface{}) (groups.Group, error) {
var g groups.Group
id, ok := data["id"].(string)
if !ok {
return groups.Group{}, errID
}
g.ID = id
stat, ok := data["status"].(string)
if !ok {
return groups.Group{}, errStatus
}
st, err := groups.ToStatus(stat)
if err != nil {
return groups.Group{}, errors.Wrap(errConvertStatus, err)
}
g.Status = st
uat, ok := data["updated_at"].(string)
if ok {
ut, err := time.Parse(layout, uat)
if err != nil {
return groups.Group{}, errors.Wrap(errUpdatedAt, err)
}
g.UpdatedAt = ut
}
uby, ok := data["updated_by"].(string)
if ok {
g.UpdatedBy = uby
}
return g, nil
}
func decodeChangeStatusGroupEvent(data map[string]interface{}) (groups.Group, error) {
g, err := ToGroupStatus(data)
if err != nil {
return groups.Group{}, errors.Wrap(errDecodeChangeStatusGroupEvent, err)
}
return g, nil
}
func decodeRemoveGroupEvent(data map[string]interface{}) (groups.Group, error) {
var g groups.Group
id, ok := data["id"].(string)
if !ok {
return groups.Group{}, errors.Wrap(errDecodeRemoveGroupEvent, errID)
}
g.ID = id
return g, nil
}
func decodeAddParentGroupEvent(data map[string]interface{}) (id string, parent string, err error) {
id, ok := data["id"].(string)
if !ok {
return "", "", errors.Wrap(errAddParentGroupEvent, errID)
}
parent, ok = data["parent_id"].(string)
if !ok {
return "", "", errors.Wrap(errDecodeAddParentGroupEvent, errParent)
}
return id, parent, nil
}
func decodeRemoveParentGroupEvent(data map[string]interface{}) (id string, err error) {
id, ok := data["id"].(string)
if !ok {
return "", errors.Wrap(errDecodeRemoveParentGroupEvent, errID)
}
return id, nil
}
func decodeAddChildrenGroupEvent(data map[string]interface{}) (id string, childrenIDs []string, err error) {
id, ok := data["id"].(string)
if !ok {
return "", []string{}, errors.Wrap(errDecodeAddChildrenGroupsEvent, errID)
}
chIDs, ok := data["children_ids"].([]interface{})
if !ok {
return "", []string{}, errors.Wrap(errDecodeAddChildrenGroupsEvent, errChildrenIDs)
}
cids, err := rconsumer.ToStrings(chIDs)
if err != nil {
return "", []string{}, errors.Wrap(errDecodeAddChildrenGroupsEvent, errors.Wrap(errChildrenIDs, err))
}
return id, cids, nil
}
func decodeRemoveChildrenGroupEvent(data map[string]interface{}) (id string, childrenIDs []string, err error) {
id, ok := data["id"].(string)
if !ok {
return "", []string{}, errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errID)
}
chIDs, ok := data["children_ids"].([]interface{})
if !ok {
return "", []string{}, errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errChildrenIDs)
}
cids, err := rconsumer.ToStrings(chIDs)
if err != nil {
return "", []string{}, errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errors.Wrap(errChildrenIDs, err))
}
return id, cids, nil
}
func decodeRemoveAllChildrenGroupEvent(data map[string]interface{}) (id string, err error) {
id, ok := data["id"].(string)
if !ok {
return "", errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errID)
}
return id, nil
}
+6
View File
@@ -0,0 +1,6 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package consumer contains events consumer for events
// published by Bootstrap service.
package consumer
+253
View File
@@ -0,0 +1,253 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package consumer
import (
"context"
"fmt"
"log/slog"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/absmach/supermq/pkg/events"
"github.com/absmach/supermq/pkg/events/store"
rconsumer "github.com/absmach/supermq/pkg/roles/rolemanager/events/consumer"
)
const (
stream = "events.supermq.groups"
create = "group.create"
update = "group.update"
changeStatus = "group.change_status"
remove = "group.remove"
addParentGroup = "group.add_parent_group"
removeParentGroup = "group.remove_parent_group"
addChildrenGroups = "group.add_children_groups"
removeChildrenGroups = "group.remove_children_groups"
removeAllChildrenGroups = "group.remove_all_children_groups"
addRole = "group.role.add"
removeRole = "group.role.remove"
updateRole = "group.role.update"
addRoleActions = "group.role.actions.add"
removeRoleActions = "group.role.actions.remove"
removeAllRoleActions = "group.role.actions.remove_all"
addRoleMembers = "group.role.members.add"
removeRoleMembers = "group.role.members.remove"
removeRoleAllMembers = "group.role.members.remove_all"
removeMemberFromAllRoles = "group.role.members.remove_from_all_roles"
)
var (
errNoOperationKey = errors.New("operation key is not found in event message")
errCreateGroupEvent = errors.New("failed to consume group create event")
errUpdateGroupEvent = errors.New("failed to consume group update event")
errChangeStatusGroupEvent = errors.New("failed to consume group change status event")
errRemoveGroupEvent = errors.New("failed to consume group remove event")
errAddParentGroupEvent = errors.New("failed to consume group add parent group event")
errRemoveParentGroupEvent = errors.New("failed to consume group remove parent group event")
errAddChildrenGroupEvent = errors.New("failed to consume group add children groups event")
errRemoveChildrenGroupEvent = errors.New("failed to consume group remove children groups event")
errRemoveAllChildrenGroupEvent = errors.New("failed to consume group remove all children groups event")
)
type eventHandler struct {
repo groups.Repository
rolesEventHandler rconsumer.EventHandler
}
func GroupsEventsSubscribe(ctx context.Context, repo groups.Repository, esURL, esConsumerName string, logger *slog.Logger) error {
subscriber, err := store.NewSubscriber(ctx, esURL, logger)
if err != nil {
return err
}
subConfig := events.SubscriberConfig{
Stream: stream,
Consumer: esConsumerName,
Handler: NewEventHandler(repo),
Ordered: true,
}
return subscriber.Subscribe(ctx, subConfig)
}
// NewEventHandler returns new event store handler.
func NewEventHandler(repo groups.Repository) events.EventHandler {
reh := rconsumer.NewEventHandler("group", repo)
return &eventHandler{
repo: repo,
rolesEventHandler: reh,
}
}
func (es *eventHandler) Handle(ctx context.Context, event events.Event) error {
msg, err := event.Encode()
if err != nil {
return err
}
op, ok := msg["operation"]
if !ok {
return errNoOperationKey
}
switch op {
case create:
return es.createGroupHandler(ctx, msg)
case update:
return es.updateGroupHandler(ctx, msg)
case changeStatus:
return es.changeStatusGroupHandler(ctx, msg)
case remove:
return es.removeGroupHandler(ctx, msg)
case addParentGroup:
return es.addParentGroupHandler(ctx, msg)
case removeParentGroup:
return es.removeParentGroupHandler(ctx, msg)
case addChildrenGroups:
return es.addChildrenGroupsHandler(ctx, msg)
case removeChildrenGroups:
return es.removeChildrenGroupsHandler(ctx, msg)
case removeAllChildrenGroups:
return es.removeAllChildrenGroupsHandler(ctx, msg)
case addRole:
return es.rolesEventHandler.AddEntityRoleHandler(ctx, msg)
case updateRole:
return es.rolesEventHandler.UpdateEntityRoleHandler(ctx, msg)
case removeRole:
return es.rolesEventHandler.RemoveEntityRoleHandler(ctx, msg)
case addRoleActions:
return es.rolesEventHandler.AddEntityRoleActionsHandler(ctx, msg)
case removeRoleActions:
return es.rolesEventHandler.RemoveEntityRoleActionsHandler(ctx, msg)
case removeAllRoleActions:
return es.rolesEventHandler.RemoveAllEntityRoleActionsHandler(ctx, msg)
case addRoleMembers:
return es.rolesEventHandler.AddEntityRoleMembersHandler(ctx, msg)
case removeRoleMembers:
return es.rolesEventHandler.RemoveEntityRoleMembersHandler(ctx, msg)
case removeRoleAllMembers:
return es.rolesEventHandler.RemoveAllEntityRoleMembersHandler(ctx, msg)
case removeMemberFromAllRoles:
return es.rolesEventHandler.RemoveMemberFromAllEntityHandler(ctx, msg)
}
return nil
}
func (es *eventHandler) createGroupHandler(ctx context.Context, data map[string]interface{}) error {
g, rps, err := decodeCreateGroupEvent(data)
if err != nil {
return errors.Wrap(errCreateGroupEvent, err)
}
if _, err := es.repo.Save(ctx, g); err != nil {
return errors.Wrap(errCreateGroupEvent, err)
}
if _, err := es.repo.AddRoles(ctx, rps); err != nil {
return errors.Wrap(errCreateGroupEvent, err)
}
return nil
}
func (es *eventHandler) updateGroupHandler(ctx context.Context, data map[string]interface{}) error {
g, err := decodeUpdateGroupEvent(data)
if err != nil {
return errors.Wrap(errUpdateGroupEvent, err)
}
if _, err := es.repo.Update(ctx, g); err != nil {
return errors.Wrap(errUpdateGroupEvent, err)
}
return nil
}
func (es *eventHandler) changeStatusGroupHandler(ctx context.Context, data map[string]interface{}) error {
g, err := decodeChangeStatusGroupEvent(data)
if err != nil {
return errors.Wrap(errChangeStatusGroupEvent, err)
}
if _, err := es.repo.ChangeStatus(ctx, g); err != nil {
return errors.Wrap(errChangeStatusGroupEvent, err)
}
return nil
}
func (es *eventHandler) removeGroupHandler(ctx context.Context, data map[string]interface{}) error {
g, err := decodeRemoveGroupEvent(data)
if err != nil {
return errors.Wrap(errRemoveGroupEvent, err)
}
if err := es.repo.Delete(ctx, g.ID); err != nil {
return errors.Wrap(errRemoveGroupEvent, err)
}
return nil
}
func (es *eventHandler) addParentGroupHandler(ctx context.Context, data map[string]interface{}) error {
id, parent, err := decodeAddParentGroupEvent(data)
if err != nil {
return errors.Wrap(errAddParentGroupEvent, err)
}
if err := es.repo.AssignParentGroup(ctx, parent, id); err != nil {
return errors.Wrap(errAddParentGroupEvent, err)
}
return nil
}
func (es *eventHandler) removeParentGroupHandler(ctx context.Context, data map[string]interface{}) error {
id, err := decodeRemoveParentGroupEvent(data)
if err != nil {
return errors.Wrap(errRemoveParentGroupEvent, err)
}
g, err := es.repo.RetrieveByID(ctx, id)
if err != nil {
return errors.Wrap(errRemoveParentGroupEvent, err)
}
fmt.Println(g, g.Parent, g.ID)
if err := es.repo.UnassignParentGroup(ctx, g.Parent, id); err != nil {
return errors.Wrap(errRemoveParentGroupEvent, err)
}
return nil
}
func (es *eventHandler) addChildrenGroupsHandler(ctx context.Context, data map[string]interface{}) error {
id, cids, err := decodeAddChildrenGroupEvent(data)
if err != nil {
return errors.Wrap(errAddChildrenGroupEvent, err)
}
if err := es.repo.AssignParentGroup(ctx, id, cids...); err != nil {
return errors.Wrap(errAddChildrenGroupEvent, err)
}
return nil
}
func (es *eventHandler) removeChildrenGroupsHandler(ctx context.Context, data map[string]interface{}) error {
id, cids, err := decodeRemoveChildrenGroupEvent(data)
if err != nil {
return errors.Wrap(errRemoveChildrenGroupEvent, err)
}
if err := es.repo.UnassignParentGroup(ctx, id, cids...); err != nil {
return errors.Wrap(errRemoveChildrenGroupEvent, err)
}
return nil
}
func (es *eventHandler) removeAllChildrenGroupsHandler(ctx context.Context, data map[string]interface{}) error {
id, err := decodeRemoveAllChildrenGroupEvent(data)
if err != nil {
return errors.Wrap(errRemoveAllChildrenGroupEvent, err)
}
if err := es.repo.UnassignAllChildrenGroups(ctx, id); err != nil && err != repoerr.ErrNotFound {
return errors.Wrap(errRemoveAllChildrenGroupEvent, err)
}
return nil
}
+6
View File
@@ -0,0 +1,6 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package events provides the events sourcing of groups to
// provide listing in clients and channels concept definitions needed to support
package events
+12 -3
View File
@@ -94,7 +94,7 @@ func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig)
return ErrEmptyTopic
}
nh := ps.natsHandler(cfg.Handler)
nh := ps.natsHandler(cfg.Handler, cfg.AckErr)
consumerConfig := jetstream.ConsumerConfig{
Name: formatConsumerName(cfg.Topic, cfg.ID),
@@ -104,6 +104,10 @@ func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig)
FilterSubject: cfg.Topic,
}
if cfg.Ordered {
consumerConfig.MaxAckPending = 1
}
switch cfg.DeliveryPolicy {
case messaging.DeliverNewPolicy:
consumerConfig.DeliverPolicy = jetstream.DeliverNewPolicy
@@ -140,17 +144,22 @@ func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error {
}
}
func (ps *pubsub) natsHandler(h messaging.MessageHandler) func(m jetstream.Msg) {
func (ps *pubsub) natsHandler(h messaging.MessageHandler, ackErr bool) func(m jetstream.Msg) {
return func(m jetstream.Msg) {
var msg messaging.Message
if err := proto.Unmarshal(m.Data(), &msg); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
return
}
if err := h.Handle(&msg); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to handle SuperMQ message: %s", err))
if ackErr {
if err := m.Ack(); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to ack message: %s", err))
}
}
return
}
if err := m.Ack(); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to ack message: %s", err))
+2
View File
@@ -39,6 +39,8 @@ type SubscriberConfig struct {
Topic string
Handler MessageHandler
DeliveryPolicy DeliveryPolicy
Ordered bool
AckErr bool
}
// Subscriber specifies message subscription API.
+9 -9
View File
@@ -29,24 +29,24 @@ func Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName string)
updated_at TIMESTAMP,
updated_by VARCHAR(254),
created_by VARCHAR(254),
CONSTRAINT unique_role_name_entity_id_constraint UNIQUE ( name, entity_id),
CONSTRAINT fk_entity_id FOREIGN KEY(entity_id) REFERENCES %s(%s) ON DELETE CASCADE
);`, rolesTableNamePrefix, entityTableName, entityIDColumnName),
CONSTRAINT %s_roles_unique_role_name_entity_id_constraint UNIQUE ( name, entity_id),
CONSTRAINT %s_roles_fk_entity_id FOREIGN KEY(entity_id) REFERENCES %s(%s) ON DELETE CASCADE
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix, entityTableName, entityIDColumnName),
fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s_role_actions (
role_id VARCHAR(254) NOT NULL,
action VARCHAR(254) NOT NULL,
CONSTRAINT unique_domain_role_action_constraint UNIQUE ( role_id, action),
CONSTRAINT fk_%s_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE
CONSTRAINT %s_role_actions_unique_domain_role_action_constraint UNIQUE ( role_id, action),
CONSTRAINT %s_role_actions_fk_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix),
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix),
fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s_role_members (
role_id VARCHAR(254) NOT NULL,
member_id VARCHAR(254) NOT NULL,
CONSTRAINT unique_role_member_constraint UNIQUE (role_id, member_id),
CONSTRAINT fk_%s_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix),
CONSTRAINT %s_role_members_unique_role_member_constraint UNIQUE (role_id, member_id),
CONSTRAINT %s_role_members_fk_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix),
},
Down: []string{
fmt.Sprintf(`DROP TABLE IF EXISTS %s_roles`, rolesTableNamePrefix),
+1 -1
View File
@@ -406,7 +406,7 @@ func (repo *Repository) RoleAddActions(ctx context.Context, role roles.Role, act
return []string{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
}
return repo.RoleListActions(ctx, role.ID)
return actions, nil
}
func (repo *Repository) RoleListActions(ctx context.Context, roleID string) ([]string, error) {
@@ -0,0 +1,146 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package consumer
import (
"time"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/roles"
)
var (
errID = errors.New("missing or invalid 'id'")
errRoleID = errors.New("missing or invalid 'role_id'")
errName = errors.New("missing or invalid 'name'")
errEntityID = errors.New("missing or invalid 'entity_id'")
errActions = errors.New("missing or invalid 'actions'")
errMembers = errors.New("missing or invalid 'members'")
errCreatedAt = errors.New("failed to parse 'created_at' time")
errUpdatedAt = errors.New("failed to parse 'updated_at' time")
errNotString = errors.New("not string type")
errInvalidRoleProvision = errors.New("invalid 'role_provisions'")
errRoleProvision = errors.New("failed to convert role_provisions interface'")
errRoleProvisionMembers = errors.New("failed to convert role_provisions member interface'")
errRoleProvisionActions = errors.New("failed to convert role_provisions action interface'")
)
const (
layout = "2006-01-02T15:04:05.999999Z"
)
func ToRole(data map[string]interface{}) (roles.Role, error) {
var r roles.Role
id, ok := data["id"].(string)
if !ok {
return roles.Role{}, errID
}
r.ID = id
name, ok := data["name"].(string)
if !ok {
return roles.Role{}, errName
}
r.Name = name
eid, ok := data["entity_id"].(string)
if !ok {
return roles.Role{}, errEntityID
}
r.EntityID = eid
// Following fields of groups are allowed to be empty.
cat, ok := data["created_at"].(string)
if ok {
ct, err := time.Parse(layout, cat)
if err != nil {
return roles.Role{}, errors.Wrap(errCreatedAt, err)
}
r.CreatedAt = ct
}
cby, ok := data["created_by"].(string)
if ok {
r.CreatedBy = cby
}
uat, ok := data["updated_at"].(string)
if ok {
ut, err := time.Parse(layout, uat)
if err != nil {
return roles.Role{}, errors.Wrap(errUpdatedAt, err)
}
r.UpdatedAt = ut
}
uby, ok := data["updated_by"].(string)
if ok {
r.UpdatedBy = uby
}
return r, nil
}
func ToStrings(data []interface{}) ([]string, error) {
var strs []string
for _, i := range data {
str, ok := i.(string)
if !ok {
return []string{}, errNotString
}
strs = append(strs, str)
}
return strs, nil
}
func ToRoleProvision(data map[string]interface{}) (roles.RoleProvision, error) {
var rp roles.RoleProvision
r, err := ToRole(data)
if err != nil {
return roles.RoleProvision{}, err
}
rp.Role = r
// Following fields of groups are allowed to be empty.
opActs, ok := data["optional_actions"].([]interface{})
if ok {
a, err := ToStrings(opActs)
if err != nil {
return roles.RoleProvision{}, errors.Wrap(errRoleProvisionActions, err)
}
rp.OptionalActions = a
}
opMems, ok := data["optional_members"].([]interface{})
if ok {
m, err := ToStrings(opMems)
if err != nil {
return roles.RoleProvision{}, errors.Wrap(errRoleProvisionMembers, err)
}
rp.OptionalMembers = m
}
return rp, nil
}
func ToRoleProvisions(data []interface{}) ([]roles.RoleProvision, error) {
var rps []roles.RoleProvision
for _, d := range data {
irp, ok := d.(map[string]interface{})
if !ok {
return []roles.RoleProvision{}, errInvalidRoleProvision
}
rp, err := ToRoleProvision(irp)
if err != nil {
return []roles.RoleProvision{}, errors.Wrap(errRoleProvision, err)
}
rps = append(rps, rp)
}
return rps, nil
}
@@ -0,0 +1,188 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package consumer
import (
"context"
"fmt"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/absmach/supermq/pkg/roles"
)
const (
errAddEntityRoleEvent = "failed to consume %s add role event : %w"
errUpdateEntityRoleEvent = "failed to consume %s update role event : %w"
errRemoveEntityRoleEvent = "failed to consume %s remove role event : %w"
errAddEntityRoleActionsEvent = "failed to consume %s add role actions event : %w"
errRemoveEntityRoleActionsEvent = "failed to consume %s remove role actions event : %w"
errRemoveEntityRoleAllActionsEvent = "failed to consume %s remove role all actions event : %w"
errAddEntityRoleMembersEvent = "failed to consume %s add role members event : %w"
errRemoveEntityRoleMembersEvent = "failed to consume %s remove role members event : %w"
errRemoveEntityRoleAllMembersEvent = "failed to consume %s remove role all members event : %w"
)
type EventHandler struct {
entityType string
repo roles.Repository
}
func NewEventHandler(entityType string, repo roles.Repository) EventHandler {
return EventHandler{
entityType: entityType,
repo: repo,
}
}
func (es *EventHandler) AddEntityRoleHandler(ctx context.Context, data map[string]interface{}) error {
rps, err := ToRoleProvision(data)
if err != nil {
return fmt.Errorf(errAddEntityRoleEvent, es.entityType, err)
}
if _, err := es.repo.AddRoles(ctx, []roles.RoleProvision{rps}); err != nil {
if !errors.Contains(err, repoerr.ErrConflict) {
return fmt.Errorf(errAddEntityRoleEvent, es.entityType, err)
}
}
return nil
}
func (es *EventHandler) UpdateEntityRoleHandler(ctx context.Context, data map[string]interface{}) error {
ro, err := ToRole(data)
if err != nil {
return fmt.Errorf(errUpdateEntityRoleEvent, es.entityType, err)
}
if _, err = es.repo.UpdateRole(ctx, ro); err != nil {
return fmt.Errorf(errUpdateEntityRoleEvent, es.entityType, err)
}
return nil
}
func (es *EventHandler) RemoveEntityRoleHandler(ctx context.Context, data map[string]interface{}) error {
id, ok := data["role_id"].(string)
if !ok {
return fmt.Errorf(errRemoveEntityRoleEvent, es.entityType, errRoleID)
}
if err := es.repo.RemoveRoles(ctx, []string{id}); err != nil {
return fmt.Errorf(errRemoveEntityRoleEvent, es.entityType, err)
}
return nil
}
func (es *EventHandler) AddEntityRoleActionsHandler(ctx context.Context, data map[string]interface{}) error {
id, ok := data["role_id"].(string)
if !ok {
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errRoleID)
}
iacts, ok := data["actions"].([]interface{})
if !ok {
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errActions)
}
acts, err := ToStrings(iacts)
if err != nil {
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err)
}
if _, err := es.repo.RoleAddActions(ctx, roles.Role{ID: id}, acts); err != nil {
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err)
}
return nil
}
func (es *EventHandler) RemoveEntityRoleActionsHandler(ctx context.Context, data map[string]interface{}) error {
id, ok := data["role_id"].(string)
if !ok {
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errRoleID)
}
iacts, ok := data["actions"].([]interface{})
if !ok {
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errActions)
}
acts, err := ToStrings(iacts)
if err != nil {
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err)
}
if err := es.repo.RoleRemoveActions(ctx, roles.Role{ID: id}, acts); err != nil {
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err)
}
return nil
}
func (es *EventHandler) RemoveAllEntityRoleActionsHandler(ctx context.Context, data map[string]interface{}) error {
id, ok := data["role_id"].(string)
if !ok {
return fmt.Errorf(errRemoveEntityRoleAllActionsEvent, es.entityType, errRoleID)
}
if err := es.repo.RoleRemoveAllActions(ctx, roles.Role{ID: id}); err != nil {
return fmt.Errorf(errRemoveEntityRoleAllActionsEvent, es.entityType, err)
}
return nil
}
func (es *EventHandler) AddEntityRoleMembersHandler(ctx context.Context, data map[string]interface{}) error {
id, ok := data["role_id"].(string)
if !ok {
return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, errRoleID)
}
imems, ok := data["members"].([]interface{})
if !ok {
return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, errMembers)
}
mems, err := ToStrings(imems)
if err != nil {
return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, err)
}
if _, err := es.repo.RoleAddMembers(ctx, roles.Role{ID: id}, mems); err != nil {
return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, err)
}
return nil
}
func (es *EventHandler) RemoveEntityRoleMembersHandler(ctx context.Context, data map[string]interface{}) error {
id, ok := data["role_id"].(string)
if !ok {
return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, errRoleID)
}
imems, ok := data["members"].([]interface{})
if !ok {
return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, errMembers)
}
mems, err := ToStrings(imems)
if err != nil {
return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, err)
}
if err := es.repo.RoleRemoveMembers(ctx, roles.Role{ID: id}, mems); err != nil {
return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, err)
}
return nil
}
func (es *EventHandler) RemoveAllEntityRoleMembersHandler(ctx context.Context, data map[string]interface{}) error {
id, ok := data["role_id"].(string)
if !ok {
return fmt.Errorf(errRemoveEntityRoleAllMembersEvent, es.entityType, errRoleID)
}
if err := es.repo.RoleRemoveAllMembers(ctx, roles.Role{ID: id}); err != nil {
return fmt.Errorf(errRemoveEntityRoleAllMembersEvent, es.entityType, err)
}
return nil
}
func (es *EventHandler) RemoveMemberFromAllEntityHandler(ctx context.Context, data map[string]interface{}) error {
return nil
}
+4 -3
View File
@@ -24,9 +24,10 @@ type RoleManagerEventStore struct {
// events to event store.
func NewRoleManagerEventStore(svcName, operationPrefix string, svc roles.RoleManager, publisher events.Publisher) RoleManagerEventStore {
return RoleManagerEventStore{
svcName: svcName,
svc: svc,
Publisher: publisher,
svcName: svcName,
operationPrefix: operationPrefix,
svc: svc,
Publisher: publisher,
}
}
+68 -43
View File
@@ -392,9 +392,11 @@ func TestListChannels(t *testing.T) {
offset: offset,
total: total,
channelsPageMeta: channels.PageMetadata{
Offset: offset,
Limit: limit,
Permission: defPermission,
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: limit,
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
@@ -417,8 +419,11 @@ func TestListChannels(t *testing.T) {
offset: offset,
limit: limit,
channelsPageMeta: channels.PageMetadata{
Offset: offset,
Limit: limit,
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: limit,
},
svcRes: channels.Page{},
authenticateErr: svcerr.ErrAuthentication,
@@ -426,16 +431,20 @@ func TestListChannels(t *testing.T) {
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "list channels with empty token",
token: "",
domainID: validID,
offset: offset,
limit: limit,
channelsPageMeta: channels.PageMetadata{},
svcRes: channels.Page{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
desc: "list channels with empty token",
token: "",
domainID: validID,
offset: offset,
limit: limit,
channelsPageMeta: channels.PageMetadata{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
},
svcRes: channels.Page{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "list channels with zero limit",
@@ -444,9 +453,11 @@ func TestListChannels(t *testing.T) {
offset: offset,
limit: 0,
channelsPageMeta: channels.PageMetadata{
Offset: offset,
Limit: 10,
Permission: defPermission,
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: 10,
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
@@ -464,16 +475,20 @@ func TestListChannels(t *testing.T) {
err: nil,
},
{
desc: "list channels with limit greater than max",
token: validToken,
domainID: domainID,
offset: offset,
limit: 110,
channelsPageMeta: channels.PageMetadata{},
svcRes: channels.Page{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
desc: "list channels with limit greater than max",
token: validToken,
domainID: domainID,
offset: offset,
limit: 110,
channelsPageMeta: channels.PageMetadata{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
},
svcRes: channels.Page{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
},
{
desc: "list channels with level",
@@ -483,9 +498,11 @@ func TestListChannels(t *testing.T) {
limit: 1,
level: 1,
channelsPageMeta: channels.PageMetadata{
Offset: offset,
Limit: 1,
Permission: defPermission,
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: 1,
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
@@ -510,10 +527,12 @@ func TestListChannels(t *testing.T) {
limit: 10,
metadata: sdk.Metadata{"name": "client_89"},
channelsPageMeta: channels.PageMetadata{
Offset: offset,
Limit: 10,
Permission: defPermission,
Metadata: clients.Metadata{"name": "client_89"},
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: 10,
Metadata: clients.Metadata{"name": "client_89"},
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
@@ -539,11 +558,15 @@ func TestListChannels(t *testing.T) {
metadata: sdk.Metadata{
"test": make(chan int),
},
channelsPageMeta: channels.PageMetadata{},
svcRes: channels.Page{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
channelsPageMeta: channels.PageMetadata{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
},
svcRes: channels.Page{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
},
{
desc: "list channels with service response that can't be unmarshalled",
@@ -552,9 +575,11 @@ func TestListChannels(t *testing.T) {
offset: 0,
limit: 10,
channelsPageMeta: channels.PageMetadata{
Offset: 0,
Limit: 10,
Permission: defPermission,
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: 0,
Limit: 10,
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
-17
View File
@@ -252,23 +252,6 @@ func (sdk mgSDK) DeleteClient(id, domainID, token string) errors.SDKError {
return sdkerr
}
func (sdk mgSDK) ListUserClients(userID, domainID string, pm PageMetadata, token string) (ClientsPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.clientsURL, fmt.Sprintf("%s/%s/%s/%s", domainID, usersEndpoint, userID, clientsEndpoint), pm)
if err != nil {
return ClientsPage{}, errors.NewSDKError(err)
}
_, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK)
if sdkerr != nil {
return ClientsPage{}, sdkerr
}
cp := ClientsPage{}
if err := json.Unmarshal(body, &cp); err != nil {
return ClientsPage{}, errors.NewSDKError(err)
}
return cp, nil
}
func (sdk mgSDK) CreateClientRole(id, domainID string, rq RoleReq, token string) (Role, errors.SDKError) {
return sdk.createRole(sdk.clientsURL, clientsEndpoint, id, domainID, rq, token)
}
+44 -275
View File
@@ -357,9 +357,11 @@ func TestListClients(t *testing.T) {
Limit: 100,
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: 0,
Limit: 100,
},
svcRes: clients.ClientsPage{
Page: clients.Page{
@@ -387,9 +389,11 @@ func TestListClients(t *testing.T) {
Limit: 100,
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: 0,
Limit: 100,
},
svcRes: clients.ClientsPage{},
authenticateErr: svcerr.ErrAuthentication,
@@ -404,7 +408,11 @@ func TestListClients(t *testing.T) {
Offset: 0,
Limit: 1000,
},
svcReq: clients.Page{},
svcReq: clients.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
},
svcRes: clients.ClientsPage{},
svcErr: nil,
response: sdk.ClientsPage{},
@@ -419,7 +427,11 @@ func TestListClients(t *testing.T) {
Limit: 100,
Name: strings.Repeat("a", 1025),
},
svcReq: clients.Page{},
svcReq: clients.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
},
svcRes: clients.ClientsPage{},
svcErr: nil,
response: sdk.ClientsPage{},
@@ -435,10 +447,12 @@ func TestListClients(t *testing.T) {
Status: clients.DisabledStatus.String(),
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
Status: clients.DisabledStatus,
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: 0,
Limit: 100,
Status: clients.DisabledStatus,
},
svcRes: clients.ClientsPage{
Page: clients.Page{
@@ -468,10 +482,12 @@ func TestListClients(t *testing.T) {
Tag: "tag1",
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
Tag: "tag1",
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: 0,
Limit: 100,
Tag: "tag1",
},
svcRes: clients.ClientsPage{
Page: clients.Page{
@@ -502,7 +518,11 @@ func TestListClients(t *testing.T) {
"test": make(chan int),
},
},
svcReq: clients.Page{},
svcReq: clients.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
},
svcRes: clients.ClientsPage{},
svcErr: nil,
response: sdk.ClientsPage{},
@@ -517,9 +537,11 @@ func TestListClients(t *testing.T) {
Limit: 100,
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: 0,
Limit: 100,
},
svcRes: clients.ClientsPage{
Page: clients.Page{
@@ -547,12 +569,12 @@ func TestListClients(t *testing.T) {
tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("ListClients", mock.Anything, tc.session, mock.Anything, tc.svcReq).Return(tc.svcRes, tc.svcErr)
svcCall := tsvc.On("ListClients", mock.Anything, tc.session, tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.Clients(tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "ListClients", mock.Anything, tc.session, mock.Anything, tc.svcReq)
ok := svcCall.Parent.AssertCalled(t, "ListClients", mock.Anything, tc.session, tc.svcReq)
assert.True(t, ok)
}
svcCall.Unset()
@@ -1400,259 +1422,6 @@ func TestDeleteClient(t *testing.T) {
}
}
func TestListUserClients(t *testing.T) {
ts, tsvc, auth := setupClients()
defer ts.Close()
var sdkClients []sdk.Client
for i := 10; i < 100; i++ {
c := generateTestClient(t)
if i == 50 {
c.Status = clients.DisabledStatus.String()
c.Tags = []string{"tag1", "tag2"}
}
sdkClients = append(sdkClients, c)
}
conf := sdk.Config{
ClientsURL: ts.URL,
}
mgsdk := sdk.NewSDK(conf)
cases := []struct {
desc string
token string
session smqauthn.Session
userID string
domainID string
pageMeta sdk.PageMetadata
svcReq clients.Page
svcRes clients.ClientsPage
svcErr error
authenticateErr error
response sdk.ClientsPage
err errors.SDKError
}{
{
desc: "list user clients successfully",
token: validToken,
userID: validID,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
},
svcRes: clients.ClientsPage{
Page: clients.Page{
Offset: 0,
Limit: 100,
Total: uint64(len(sdkClients)),
},
Clients: convertClients(sdkClients...),
},
svcErr: nil,
response: sdk.ClientsPage{
PageRes: sdk.PageRes{
Limit: 100,
Total: uint64(len(sdkClients)),
},
Clients: sdkClients,
},
},
{
desc: "list user clients with an invalid token",
token: invalidToken,
userID: validID,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
},
svcRes: clients.ClientsPage{},
authenticateErr: svcerr.ErrAuthentication,
response: sdk.ClientsPage{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "list user clients with limit greater than max",
token: validToken,
userID: validID,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 1000,
},
svcReq: clients.Page{},
svcRes: clients.ClientsPage{},
svcErr: nil,
response: sdk.ClientsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
},
{
desc: "list user clients with name size greater than max",
token: validToken,
userID: validID,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
Name: strings.Repeat("a", 1025),
},
svcReq: clients.Page{},
svcRes: clients.ClientsPage{},
svcErr: nil,
response: sdk.ClientsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
},
{
desc: "list user clients with status",
token: validToken,
userID: validID,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
Status: clients.DisabledStatus.String(),
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
Status: clients.DisabledStatus,
},
svcRes: clients.ClientsPage{
Page: clients.Page{
Offset: 0,
Limit: 100,
Total: 1,
},
Clients: convertClients(sdkClients[50]),
},
svcErr: nil,
response: sdk.ClientsPage{
PageRes: sdk.PageRes{
Limit: 100,
Total: 1,
},
Clients: []sdk.Client{sdkClients[50]},
},
err: nil,
},
{
desc: "list user clients with tags",
token: validToken,
userID: validID,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
Tag: "tag1",
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
Tag: "tag1",
},
svcRes: clients.ClientsPage{
Page: clients.Page{
Offset: 0,
Limit: 100,
Total: 1,
},
Clients: convertClients(sdkClients[50]),
},
svcErr: nil,
response: sdk.ClientsPage{
PageRes: sdk.PageRes{
Limit: 100,
Total: 1,
},
Clients: []sdk.Client{sdkClients[50]},
},
err: nil,
},
{
desc: "list user clients with invalid metadata",
token: validToken,
userID: validID,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
Metadata: map[string]interface{}{
"test": make(chan int),
},
},
svcReq: clients.Page{},
svcRes: clients.ClientsPage{},
svcErr: nil,
response: sdk.ClientsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
},
{
desc: "list user clients with response that can't be unmarshalled",
token: validToken,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
},
svcReq: clients.Page{
Offset: 0,
Limit: 100,
Permission: defPermission,
},
svcRes: clients.ClientsPage{
Page: clients.Page{
Offset: 0,
Limit: 100,
Total: 1,
},
Clients: []clients.Client{{
Name: sdkClients[0].Name,
Tags: sdkClients[0].Tags,
Credentials: clients.Credentials(sdkClients[0].Credentials),
Metadata: clients.Metadata{
"test": make(chan int),
},
}},
},
svcErr: nil,
response: sdk.ClientsPage{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("ListClients", mock.Anything, tc.session, tc.userID, tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.ListUserClients(tc.userID, tc.domainID, tc.pageMeta, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "ListClients", mock.Anything, tc.session, tc.userID, tc.svcReq)
assert.True(t, ok)
}
svcCall.Unset()
authCall.Unset()
})
}
}
func TestSetClientParent(t *testing.T) {
ts, csvc, auth := setupClients()
defer ts.Close()
-61
View File
@@ -4393,67 +4393,6 @@ func (_c *SDK_ListDomainUsers_Call) RunAndReturn(run func(string, sdk.PageMetada
return _c
}
// ListUserClients provides a mock function with given fields: userID, domainID, pm, token
func (_m *SDK) ListUserClients(userID string, domainID string, pm sdk.PageMetadata, token string) (sdk.ClientsPage, errors.SDKError) {
ret := _m.Called(userID, domainID, pm, token)
if len(ret) == 0 {
panic("no return value specified for ListUserClients")
}
var r0 sdk.ClientsPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(string, string, sdk.PageMetadata, string) (sdk.ClientsPage, errors.SDKError)); ok {
return rf(userID, domainID, pm, token)
}
if rf, ok := ret.Get(0).(func(string, string, sdk.PageMetadata, string) sdk.ClientsPage); ok {
r0 = rf(userID, domainID, pm, token)
} else {
r0 = ret.Get(0).(sdk.ClientsPage)
}
if rf, ok := ret.Get(1).(func(string, string, sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(userID, domainID, pm, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
}
}
return r0, r1
}
// SDK_ListUserClients_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserClients'
type SDK_ListUserClients_Call struct {
*mock.Call
}
// ListUserClients is a helper method to define mock.On call
// - userID string
// - domainID string
// - pm sdk.PageMetadata
// - token string
func (_e *SDK_Expecter) ListUserClients(userID interface{}, domainID interface{}, pm interface{}, token interface{}) *SDK_ListUserClients_Call {
return &SDK_ListUserClients_Call{Call: _e.mock.On("ListUserClients", userID, domainID, pm, token)}
}
func (_c *SDK_ListUserClients_Call) Run(run func(userID string, domainID string, pm sdk.PageMetadata, token string)) *SDK_ListUserClients_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(string), args[2].(sdk.PageMetadata), args[3].(string))
})
return _c
}
func (_c *SDK_ListUserClients_Call) Return(_a0 sdk.ClientsPage, _a1 errors.SDKError) *SDK_ListUserClients_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *SDK_ListUserClients_Call) RunAndReturn(run func(string, string, sdk.PageMetadata, string) (sdk.ClientsPage, errors.SDKError)) *SDK_ListUserClients_Call {
_c.Call.Return(run)
return _c
}
// Members provides a mock function with given fields: groupID, domainID, pm, token
func (_m *SDK) Members(groupID string, domainID string, pm sdk.PageMetadata, token string) (sdk.UsersPage, errors.SDKError) {
ret := _m.Called(groupID, domainID, pm, token)
-11
View File
@@ -515,17 +515,6 @@ type SDK interface {
// fmt.Println(err)
RemoveClientParent(id, domainID, groupID, token string) errors.SDKError
// ListUserClients returns list of clients for the given user ID and filters.
//
// example:
// pm := sdk.PageMetadata{
// Offset: 0,
// Limit: 10,
// }
// clients, _ := sdk.ListUserClients("userID", "domainID", pm,"token")
// fmt.Println(clients)
ListUserClients(userID, domainID string, pm PageMetadata, token string) (ClientsPage, errors.SDKError)
// CreateClientRole creates new client role and returns its id.
//
// example:
-1
View File
@@ -216,7 +216,6 @@ func convertChannel(g sdk.Channel) mgchannels.Channel {
UpdatedAt: g.UpdatedAt,
UpdatedBy: g.UpdatedBy,
Status: status,
Permissions: g.Permissions,
}
}