mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
SMQ-2605 - Groups replication with groups events consumer & listing of things and channels (#2639)
Signed-off-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
@@ -58,6 +58,7 @@ const (
|
||||
UserKey = "user"
|
||||
DomainKey = "domain"
|
||||
ChannelKey = "channel"
|
||||
ConnTypeKey = "connection_type"
|
||||
DefPermission = "read_permission"
|
||||
DefTotal = uint64(100)
|
||||
DefOffset = 0
|
||||
|
||||
+90
-42
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
smqclients "github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
@@ -51,58 +51,106 @@ func decodeCreateChannelsReq(_ context.Context, r *http.Request) (interface{}, e
|
||||
}
|
||||
|
||||
func decodeListChannels(_ context.Context, r *http.Request) (interface{}, error) {
|
||||
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefClientStatus)
|
||||
name, err := apiutil.ReadStringQuery(r, api.NameKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
m, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
n, err := apiutil.ReadStringQuery(r, api.NameKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
t, err := apiutil.ReadStringQuery(r, api.TagKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
p, err := apiutil.ReadStringQuery(r, api.PermissionKey, api.DefPermission)
|
||||
|
||||
tag, err := apiutil.ReadStringQuery(r, api.TagKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
lp, err := apiutil.ReadBoolQuery(r, api.ListPerms, api.DefListPerms)
|
||||
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefGroupStatus)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
st, err := smqclients.ToStatus(s)
|
||||
status, err := clients.ToStatus(s)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
meta, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
allActions, err := apiutil.ReadStringQuery(r, api.ActionsKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
actions := []string{}
|
||||
|
||||
allActions = strings.TrimSpace(allActions)
|
||||
if allActions != "" {
|
||||
actions = strings.Split(allActions, ",")
|
||||
}
|
||||
roleID, err := apiutil.ReadStringQuery(r, api.RoleIDKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
roleName, err := apiutil.ReadStringQuery(r, api.RoleNameKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
accessType, err := apiutil.ReadStringQuery(r, api.AccessTypeKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
userID, err := apiutil.ReadStringQuery(r, api.UserKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
groupID, err := apiutil.ReadStringQuery(r, api.GroupKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
clientID, err := apiutil.ReadStringQuery(r, api.ClientKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
req := listChannelsReq{
|
||||
status: st,
|
||||
offset: o,
|
||||
limit: l,
|
||||
metadata: m,
|
||||
name: n,
|
||||
tag: t,
|
||||
permission: p,
|
||||
listPerms: lp,
|
||||
userID: chi.URLParam(r, "userID"),
|
||||
id: id,
|
||||
name: name,
|
||||
tag: tag,
|
||||
status: status,
|
||||
metadata: meta,
|
||||
roleName: roleName,
|
||||
roleID: roleID,
|
||||
actions: actions,
|
||||
accessType: accessType,
|
||||
order: order,
|
||||
dir: dir,
|
||||
offset: offset,
|
||||
limit: limit,
|
||||
groupID: groupID,
|
||||
clientID: clientID,
|
||||
userID: userID,
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -104,15 +104,21 @@ func listChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
}
|
||||
|
||||
pm := channels.PageMetadata{
|
||||
Status: req.status,
|
||||
Offset: req.offset,
|
||||
Limit: req.limit,
|
||||
Name: req.name,
|
||||
Tag: req.tag,
|
||||
Permission: req.permission,
|
||||
Metadata: req.metadata,
|
||||
ListPerms: req.listPerms,
|
||||
Id: req.id,
|
||||
Offset: req.offset,
|
||||
Limit: req.limit,
|
||||
Name: req.name,
|
||||
Order: req.order,
|
||||
Dir: req.dir,
|
||||
Metadata: req.metadata,
|
||||
Tag: req.tag,
|
||||
Status: req.status,
|
||||
Group: req.groupID,
|
||||
Client: req.clientID,
|
||||
ConnectionType: req.connType,
|
||||
RoleName: req.roleName,
|
||||
RoleID: req.roleID,
|
||||
Actions: req.actions,
|
||||
AccessType: req.accessType,
|
||||
}
|
||||
page, err := svc.ListChannels(ctx, session, pm)
|
||||
if err != nil {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/channels"
|
||||
smqclients "github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
)
|
||||
|
||||
@@ -64,29 +64,29 @@ func (req viewChannelReq) validate() error {
|
||||
}
|
||||
|
||||
type listChannelsReq struct {
|
||||
status smqclients.Status
|
||||
offset uint64
|
||||
limit uint64
|
||||
name string
|
||||
tag string
|
||||
permission string
|
||||
visibility string
|
||||
status clients.Status
|
||||
metadata clients.Metadata
|
||||
roleName string
|
||||
roleID string
|
||||
actions []string
|
||||
accessType string
|
||||
order string
|
||||
dir string
|
||||
offset uint64
|
||||
limit uint64
|
||||
groupID string
|
||||
clientID string
|
||||
connType string
|
||||
userID string
|
||||
listPerms bool
|
||||
metadata smqclients.Metadata
|
||||
id string
|
||||
}
|
||||
|
||||
func (req listChannelsReq) validate() error {
|
||||
if req.limit > api.MaxLimitSize || req.limit < 1 {
|
||||
return apiutil.ErrLimitSize
|
||||
}
|
||||
if req.visibility != "" &&
|
||||
req.visibility != api.AllVisibility &&
|
||||
req.visibility != api.MyVisibility &&
|
||||
req.visibility != api.SharedVisibility {
|
||||
return apiutil.ErrInvalidVisibilityType
|
||||
}
|
||||
|
||||
if len(req.name) > api.MaxNameSize {
|
||||
return apiutil.ErrNameSize
|
||||
}
|
||||
|
||||
@@ -174,14 +174,6 @@ func TestListChannelsReqValidation(t *testing.T) {
|
||||
},
|
||||
err: apiutil.ErrNameSize,
|
||||
},
|
||||
{
|
||||
desc: "invalid visibility",
|
||||
req: listChannelsReq{
|
||||
limit: 10,
|
||||
visibility: "invalid",
|
||||
},
|
||||
err: apiutil.ErrInvalidVisibilityType,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
|
||||
+41
-27
@@ -22,29 +22,43 @@ type Channel struct {
|
||||
ParentGroup string `json:"parent_group_id,omitempty"`
|
||||
Domain string `json:"domain_id,omitempty"`
|
||||
Metadata clients.Metadata `json:"metadata,omitempty"`
|
||||
CreatedBy string `json:"created_by,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
UpdatedBy string `json:"updated_by,omitempty"`
|
||||
Status clients.Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
|
||||
Permissions []string `json:"permissions,omitempty"` // 1 for enabled, 0 for disabled
|
||||
Status clients.Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
|
||||
// Extended
|
||||
ParentGroupPath string `json:"parent_group_path"`
|
||||
RoleID string `json:"role_id"`
|
||||
RoleName string `json:"role_name"`
|
||||
Actions []string `json:"actions"`
|
||||
AccessType string `json:"access_type"`
|
||||
AccessProviderId string `json:"access_provider_id"`
|
||||
AccessProviderRoleId string `json:"access_provider_role_id"`
|
||||
AccessProviderRoleName string `json:"access_provider_role_name"`
|
||||
AccessProviderRoleActions []string `json:"access_provider_role_actions"`
|
||||
}
|
||||
|
||||
type PageMetadata struct {
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Order string `json:"order,omitempty"`
|
||||
Dir string `json:"dir,omitempty"`
|
||||
Metadata clients.Metadata `json:"metadata,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Permission string `json:"permission,omitempty"`
|
||||
Status clients.Status `json:"status,omitempty"`
|
||||
IDs []string `json:"ids,omitempty"`
|
||||
ListPerms bool `json:"-"`
|
||||
ClientID string `json:"-"`
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Order string `json:"order,omitempty"`
|
||||
Dir string `json:"dir,omitempty"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Metadata clients.Metadata `json:"metadata,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Status clients.Status `json:"status,omitempty"`
|
||||
Group string `json:"group,omitempty"`
|
||||
Client string `json:"client,omitempty"`
|
||||
ConnectionType string `json:"connection_type,omitempty"`
|
||||
RoleName string `json:"role_name,omitempty"`
|
||||
RoleID string `json:"role_id,omitempty"`
|
||||
Actions []string `json:"actions,omitempty"`
|
||||
AccessType string `json:"access_type,omitempty"`
|
||||
IDs []string `json:"-"`
|
||||
}
|
||||
|
||||
// ChannelsPage contains page related metadata as well as list of channels that
|
||||
@@ -71,15 +85,15 @@ type AuthzReq struct {
|
||||
|
||||
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Service interface {
|
||||
// CreateChannels adds channels to the user identified by the provided key.
|
||||
// CreateChannels adds channels to the user.
|
||||
CreateChannels(ctx context.Context, session authn.Session, channels ...Channel) ([]Channel, []roles.RoleProvision, error)
|
||||
|
||||
// ViewChannel retrieves data about the channel identified by the provided
|
||||
// ID, that belongs to the user identified by the provided key.
|
||||
// ID, that belongs to the user.
|
||||
ViewChannel(ctx context.Context, session authn.Session, id string) (Channel, error)
|
||||
|
||||
// UpdateChannel updates the channel identified by the provided ID, that
|
||||
// belongs to the user identified by the provided key.
|
||||
// belongs to the user.
|
||||
UpdateChannel(ctx context.Context, session authn.Session, channel Channel) (Channel, error)
|
||||
|
||||
// UpdateChannelTags updates the channel's tags.
|
||||
@@ -89,17 +103,14 @@ type Service interface {
|
||||
|
||||
DisableChannel(ctx context.Context, session authn.Session, id string) (Channel, error)
|
||||
|
||||
// ListChannels retrieves data about subset of channels that belongs to the
|
||||
// user identified by the provided key.
|
||||
// ListChannels retrieves data about subset of channels that belongs to the user.
|
||||
ListChannels(ctx context.Context, session authn.Session, pm PageMetadata) (Page, error)
|
||||
|
||||
// ListChannelsByClient retrieves data about subset of channels that have
|
||||
// specified client connected or not connected to them and belong to the user identified by
|
||||
// the provided key.
|
||||
ListChannelsByClient(ctx context.Context, session authn.Session, id string, pm PageMetadata) (Page, error)
|
||||
// ListUserChannels retrieves data about subset of channels that belong to the specified user.
|
||||
ListUserChannels(ctx context.Context, session authn.Session, userID string, pm PageMetadata) (Page, error)
|
||||
|
||||
// RemoveChannel removes the client identified by the provided ID, that
|
||||
// belongs to the user identified by the provided key.
|
||||
// belongs to the user.
|
||||
RemoveChannel(ctx context.Context, session authn.Session, id string) error
|
||||
|
||||
// Connect adds clients to the channels list of connected clients.
|
||||
@@ -131,6 +142,9 @@ type Repository interface {
|
||||
|
||||
ChangeStatus(ctx context.Context, channel Channel) (Channel, error)
|
||||
|
||||
// RetrieveUserChannels retrieves the channel of given domainID and userID.
|
||||
RetrieveUserChannels(ctx context.Context, domainID, userID string, pm PageMetadata) (Page, error)
|
||||
|
||||
// RetrieveByID retrieves the channel having the provided identifier
|
||||
RetrieveByID(ctx context.Context, id string) (Channel, error)
|
||||
|
||||
|
||||
+27
-30
@@ -206,9 +206,6 @@ func (lce listChannelEvent) Encode() (map[string]interface{}, error) {
|
||||
if lce.Tag != "" {
|
||||
val["tag"] = lce.Tag
|
||||
}
|
||||
if lce.Permission != "" {
|
||||
val["permission"] = lce.Permission
|
||||
}
|
||||
if lce.Status.String() != "" {
|
||||
val["status"] = lce.Status.String()
|
||||
}
|
||||
@@ -219,48 +216,48 @@ func (lce listChannelEvent) Encode() (map[string]interface{}, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type listChannelByClientEvent struct {
|
||||
clientID string
|
||||
type listUserChannelsEvent struct {
|
||||
userID string
|
||||
channels.PageMetadata
|
||||
authn.Session
|
||||
}
|
||||
|
||||
func (lcte listChannelByClientEvent) Encode() (map[string]interface{}, error) {
|
||||
func (luce listUserChannelsEvent) Encode() (map[string]interface{}, error) {
|
||||
val := map[string]interface{}{
|
||||
"operation": channelList,
|
||||
"client_id": lcte.clientID,
|
||||
"total": lcte.Total,
|
||||
"offset": lcte.Offset,
|
||||
"limit": lcte.Limit,
|
||||
"domain": lcte.DomainID,
|
||||
"user_id": lcte.UserID,
|
||||
"token_type": lcte.Type.String(),
|
||||
"super_admin": lcte.SuperAdmin,
|
||||
"req_user_id": luce.userID,
|
||||
"total": luce.Total,
|
||||
"offset": luce.Offset,
|
||||
"limit": luce.Limit,
|
||||
"domain": luce.DomainID,
|
||||
"user_id": luce.UserID,
|
||||
"token_type": luce.Type.String(),
|
||||
"super_admin": luce.SuperAdmin,
|
||||
}
|
||||
|
||||
if lcte.Name != "" {
|
||||
val["name"] = lcte.Name
|
||||
if luce.Name != "" {
|
||||
val["name"] = luce.Name
|
||||
}
|
||||
if lcte.Order != "" {
|
||||
val["order"] = lcte.Order
|
||||
if luce.Order != "" {
|
||||
val["order"] = luce.Order
|
||||
}
|
||||
if lcte.Dir != "" {
|
||||
val["dir"] = lcte.Dir
|
||||
if luce.Dir != "" {
|
||||
val["dir"] = luce.Dir
|
||||
}
|
||||
if lcte.Metadata != nil {
|
||||
val["metadata"] = lcte.Metadata
|
||||
if luce.Metadata != nil {
|
||||
val["metadata"] = luce.Metadata
|
||||
}
|
||||
if lcte.Tag != "" {
|
||||
val["tag"] = lcte.Tag
|
||||
if luce.Domain != "" {
|
||||
val["domain"] = luce.Domain
|
||||
}
|
||||
if lcte.Permission != "" {
|
||||
val["permission"] = lcte.Permission
|
||||
if luce.Tag != "" {
|
||||
val["tag"] = luce.Tag
|
||||
}
|
||||
if lcte.Status.String() != "" {
|
||||
val["status"] = lcte.Status.String()
|
||||
if luce.Status.String() != "" {
|
||||
val["status"] = luce.Status.String()
|
||||
}
|
||||
if len(lcte.IDs) > 0 {
|
||||
val["ids"] = lcte.IDs
|
||||
if len(luce.IDs) > 0 {
|
||||
val["ids"] = luce.IDs
|
||||
}
|
||||
|
||||
return val, nil
|
||||
|
||||
@@ -126,13 +126,13 @@ func (es *eventStore) ListChannels(ctx context.Context, session authn.Session, p
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
cp, err := es.svc.ListChannelsByClient(ctx, session, clientID, pm)
|
||||
func (es *eventStore) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
cp, err := es.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
event := listChannelByClientEvent{
|
||||
clientID: clientID,
|
||||
event := listUserChannelsEvent{
|
||||
userID: userID,
|
||||
PageMetadata: pm,
|
||||
Session: session,
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
var (
|
||||
errView = errors.New("not authorized to view channel")
|
||||
errList = errors.New("not authorized to list user channels")
|
||||
errUpdate = errors.New("not authorized to update channel")
|
||||
errUpdateTags = errors.New("not authorized to update channel tags")
|
||||
errEnable = errors.New("not authorized to enable channel")
|
||||
@@ -164,13 +165,13 @@ func (am *authorizationMiddleware) ListChannels(ctx context.Context, session aut
|
||||
}
|
||||
}
|
||||
|
||||
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
|
||||
if err := am.checkSuperAdmin(ctx, session.UserID); err == nil {
|
||||
session.SuperAdmin = true
|
||||
}
|
||||
return am.svc.ListChannels(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
if session.Type == authn.PersonalAccessToken {
|
||||
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
|
||||
UserID: session.UserID,
|
||||
@@ -184,7 +185,10 @@ func (am *authorizationMiddleware) ListChannelsByClient(ctx context.Context, ses
|
||||
return channels.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
|
||||
}
|
||||
}
|
||||
return am.svc.ListChannelsByClient(ctx, session, clientID, pm)
|
||||
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
|
||||
return channels.Page{}, errors.Wrap(err, errList)
|
||||
}
|
||||
return am.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
|
||||
|
||||
@@ -82,11 +82,11 @@ func (lm *loggingMiddleware) ListChannels(ctx context.Context, session authn.Ses
|
||||
return lm.svc.ListChannels(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (cp channels.Page, err error) {
|
||||
func (lm *loggingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (cp channels.Page, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("client_id", clientID),
|
||||
slog.String("user_id", userID),
|
||||
slog.Group("page",
|
||||
slog.Uint64("limit", pm.Limit),
|
||||
slog.Uint64("offset", pm.Offset),
|
||||
@@ -95,12 +95,12 @@ func (lm *loggingMiddleware) ListChannelsByClient(ctx context.Context, session a
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("List channels by client failed", args...)
|
||||
lm.logger.Warn("List user channels failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("List channels by client completed successfully", args...)
|
||||
lm.logger.Info("List user channels completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.ListChannelsByClient(ctx, session, clientID, pm)
|
||||
return lm.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) UpdateChannel(ctx context.Context, session authn.Session, client channels.Channel) (c channels.Channel, err error) {
|
||||
|
||||
@@ -58,12 +58,12 @@ func (ms *metricsMiddleware) ListChannels(ctx context.Context, session authn.Ses
|
||||
return ms.svc.ListChannels(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
func (ms *metricsMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "list_channels_by_client").Add(1)
|
||||
ms.latency.With("method", "list_channels_by_client").Observe(time.Since(begin).Seconds())
|
||||
ms.counter.With("method", "list_user_channels").Add(1)
|
||||
ms.latency.With("method", "list_user_channels").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.ListChannelsByClient(ctx, session, clientID, pm)
|
||||
return ms.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
|
||||
|
||||
@@ -529,6 +529,34 @@ func (_m *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Ro
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RetrieveUserChannels provides a mock function with given fields: ctx, domainID, userID, pm
|
||||
func (_m *Repository) RetrieveUserChannels(ctx context.Context, domainID string, userID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
ret := _m.Called(ctx, domainID, userID, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveUserChannels")
|
||||
}
|
||||
|
||||
var r0 channels.Page
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.PageMetadata) (channels.Page, error)); ok {
|
||||
return rf(ctx, domainID, userID, pm)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.PageMetadata) channels.Page); ok {
|
||||
r0 = rf(ctx, domainID, userID, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(channels.Page)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, channels.PageMetadata) error); ok {
|
||||
r1 = rf(ctx, domainID, userID, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RoleAddActions provides a mock function with given fields: ctx, role, actions
|
||||
func (_m *Repository) RoleAddActions(ctx context.Context, role roles.Role, actions []string) ([]string, error) {
|
||||
ret := _m.Called(ctx, role, actions)
|
||||
|
||||
@@ -246,27 +246,27 @@ func (_m *Service) ListChannels(ctx context.Context, session authn.Session, pm c
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListChannelsByClient provides a mock function with given fields: ctx, session, id, pm
|
||||
func (_m *Service) ListChannelsByClient(ctx context.Context, session authn.Session, id string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
ret := _m.Called(ctx, session, id, pm)
|
||||
// ListUserChannels provides a mock function with given fields: ctx, session, userID, pm
|
||||
func (_m *Service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
ret := _m.Called(ctx, session, userID, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListChannelsByClient")
|
||||
panic("no return value specified for ListUserChannels")
|
||||
}
|
||||
|
||||
var r0 channels.Page
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.PageMetadata) (channels.Page, error)); ok {
|
||||
return rf(ctx, session, id, pm)
|
||||
return rf(ctx, session, userID, pm)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.PageMetadata) channels.Page); ok {
|
||||
r0 = rf(ctx, session, id, pm)
|
||||
r0 = rf(ctx, session, userID, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(channels.Page)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, channels.PageMetadata) error); ok {
|
||||
r1 = rf(ctx, session, id, pm)
|
||||
r1 = rf(ctx, session, userID, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
+404
-50
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/absmach/supermq/pkg/postgres"
|
||||
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -152,7 +153,7 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe
|
||||
query = applyOrdering(query, pm)
|
||||
|
||||
q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.metadata, COALESCE(c.domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, c.status,
|
||||
c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query)
|
||||
c.created_by, c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c %s LIMIT :limit OFFSET :offset;`, query)
|
||||
|
||||
dbPage, err := toDBChannelsPage(pm)
|
||||
if err != nil {
|
||||
@@ -196,6 +197,303 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (repo *channelRepository) RetrieveUserChannels(ctx context.Context, domainID, userID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
return repo.retrieveClients(ctx, domainID, userID, pm)
|
||||
}
|
||||
|
||||
func (repo *channelRepository) retrieveClients(ctx context.Context, domainID, userID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
pageQuery, err := PageQuery(pm)
|
||||
if err != nil {
|
||||
return channels.Page{}, err
|
||||
}
|
||||
|
||||
bq := repo.userChannelsBaseQuery(domainID, userID)
|
||||
|
||||
connJoinQuery := ""
|
||||
if pm.Client != "" {
|
||||
connJoinQuery = "JOIN connection conn ON conn.channel_id = c.id"
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
%s
|
||||
SELECT
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_by,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
c.parent_group_path,
|
||||
c.role_id,
|
||||
c.role_name,
|
||||
c.actions,
|
||||
c.access_type,
|
||||
c.access_provider_id,
|
||||
c.access_provider_role_id,
|
||||
c.access_provider_role_name,
|
||||
c.access_provider_role_actions
|
||||
FROM
|
||||
final_channels c
|
||||
%s
|
||||
%s
|
||||
`, bq, connJoinQuery, pageQuery)
|
||||
|
||||
q = applyOrdering(q, pm)
|
||||
|
||||
dbPage, err := toDBChannelsPage(pm)
|
||||
if err != nil {
|
||||
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []channels.Channel
|
||||
for rows.Next() {
|
||||
dbc := dbChannel{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := toChannel(dbc)
|
||||
if err != nil {
|
||||
return channels.Page{}, err
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
}
|
||||
|
||||
chJoinQuery := ""
|
||||
if pm.Client != "" {
|
||||
chJoinQuery = "JOIN connection conn ON conn.channel_id = c.id"
|
||||
}
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM (
|
||||
SELECT
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_by,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
c.parent_group_path,
|
||||
c.role_id,
|
||||
c.role_name,
|
||||
c.actions,
|
||||
c.access_type,
|
||||
c.access_provider_id,
|
||||
c.access_provider_role_id,
|
||||
c.access_provider_role_name,
|
||||
c.access_provider_role_actions
|
||||
FROM
|
||||
final_channels c
|
||||
%s
|
||||
%s
|
||||
) AS subquery;
|
||||
`, bq, chJoinQuery, pageQuery)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := channels.Page{
|
||||
Channels: items,
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}
|
||||
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (repo *channelRepository) userChannelsBaseQuery(domainID, userID string) string {
|
||||
return fmt.Sprintf(`
|
||||
WITH direct_channels AS (
|
||||
select
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_by,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
text2ltree('') as parent_group_path,
|
||||
cr.id AS role_id,
|
||||
cr."name" AS role_name,
|
||||
array_agg(cra."action") AS actions,
|
||||
'direct' as access_type,
|
||||
'' AS access_provider_id,
|
||||
'' AS access_provider_role_id,
|
||||
'' AS access_provider_role_name,
|
||||
array[]::::text[] AS access_provider_role_actions
|
||||
FROM
|
||||
channels_role_members crm
|
||||
JOIN
|
||||
channels_role_actions cra ON cra.role_id = crm.role_id
|
||||
JOIN
|
||||
channels_roles cr ON cr.id = crm.role_id
|
||||
JOIN
|
||||
channels c ON c.id = cr.entity_id
|
||||
WHERE
|
||||
crm.member_id = '%s'
|
||||
AND c.domain_id = '%s'
|
||||
GROUP BY
|
||||
cr.entity_id, crm.member_id, cr.id, cr."name", c.id
|
||||
),
|
||||
direct_groups AS (
|
||||
SELECT
|
||||
g.*,
|
||||
gr.entity_id AS entity_id,
|
||||
grm.member_id AS member_id,
|
||||
gr.id AS role_id,
|
||||
gr."name" AS role_name,
|
||||
array_agg(gra."action") AS actions
|
||||
FROM
|
||||
groups_role_members grm
|
||||
JOIN
|
||||
groups_role_actions gra ON gra.role_id = grm.role_id
|
||||
JOIN
|
||||
groups_roles gr ON gr.id = grm.role_id
|
||||
JOIN
|
||||
"groups" g ON g.id = gr.entity_id
|
||||
WHERE
|
||||
grm.member_id = '%s'
|
||||
AND g.domain_id = '%s'
|
||||
GROUP BY
|
||||
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
|
||||
),
|
||||
direct_groups_with_subgroup AS (
|
||||
SELECT
|
||||
*
|
||||
FROM direct_groups
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM unnest(direct_groups.actions) AS action
|
||||
WHERE action LIKE 'subgroup_%%'
|
||||
)
|
||||
),
|
||||
indirect_child_groups AS (
|
||||
SELECT
|
||||
DISTINCT indirect_child_groups.id as child_id,
|
||||
indirect_child_groups.*,
|
||||
dgws.id as access_provider_id,
|
||||
dgws.role_id as access_provider_role_id,
|
||||
dgws.role_name as access_provider_role_name,
|
||||
dgws.actions as access_provider_role_actions
|
||||
FROM
|
||||
direct_groups_with_subgroup dgws
|
||||
JOIN
|
||||
groups indirect_child_groups ON indirect_child_groups.path <@ dgws.path
|
||||
WHERE
|
||||
indirect_child_groups.domain_id = '%s'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM direct_groups_with_subgroup dgws
|
||||
WHERE dgws.id = indirect_child_groups.id
|
||||
)
|
||||
),
|
||||
final_groups AS (
|
||||
SELECT
|
||||
id,
|
||||
parent_id,
|
||||
domain_id,
|
||||
"name",
|
||||
description,
|
||||
metadata,
|
||||
created_at,
|
||||
updated_at,
|
||||
updated_by,
|
||||
status,
|
||||
"path",
|
||||
role_id,
|
||||
role_name,
|
||||
actions,
|
||||
'direct_group' AS access_type,
|
||||
'' AS access_provider_id,
|
||||
'' AS access_provider_role_id,
|
||||
'' AS access_provider_role_name,
|
||||
array[]::::text[] AS access_provider_role_actions
|
||||
FROM
|
||||
direct_groups
|
||||
UNION
|
||||
SELECT
|
||||
id,
|
||||
parent_id,
|
||||
domain_id,
|
||||
"name",
|
||||
description,
|
||||
metadata,
|
||||
created_at,
|
||||
updated_at,
|
||||
updated_by,
|
||||
status,
|
||||
"path",
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
'indirect_group' AS access_type,
|
||||
access_provider_id,
|
||||
access_provider_role_id,
|
||||
access_provider_role_name,
|
||||
access_provider_role_actions
|
||||
FROM
|
||||
indirect_child_groups
|
||||
),
|
||||
final_channels AS (
|
||||
SELECT
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_by,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
g.path AS parent_group_path,
|
||||
g.role_id,
|
||||
g.role_name,
|
||||
g.actions,
|
||||
g.access_type,
|
||||
g.access_provider_id,
|
||||
g.access_provider_role_id,
|
||||
g.access_provider_role_name,
|
||||
g.access_provider_role_actions
|
||||
FROM
|
||||
final_groups g
|
||||
JOIN
|
||||
channels c ON c.parent_group_id = g.id
|
||||
WHERE
|
||||
c.id NOT IN (SELECT id FROM direct_channels)
|
||||
UNION
|
||||
SELECT * FROM direct_channels
|
||||
)
|
||||
`, userID, domainID, userID, domainID, domainID)
|
||||
}
|
||||
|
||||
func (cr *channelRepository) Remove(ctx context.Context, ids ...string) error {
|
||||
q := "DELETE FROM channels AS c WHERE c.id = ANY(:channel_ids) ;"
|
||||
params := map[string]interface{}{
|
||||
@@ -361,7 +659,7 @@ func (cr *channelRepository) RemoveChannelConnections(ctx context.Context, chann
|
||||
|
||||
func (cr *channelRepository) RetrieveParentGroupChannels(ctx context.Context, parentGroupID string) ([]channels.Channel, error) {
|
||||
query := `SELECT c.id, c.name, c.tags, c.metadata, COALESCE(c.domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, c.status,
|
||||
c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c WHERE c.parent_group_id = :parent_group_id ;`
|
||||
c.created_by, c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c WHERE c.parent_group_id = :parent_group_id ;`
|
||||
|
||||
rows, err := cr.db.NamedQueryContext(ctx, query, dbChannel{ParentGroup: toNullString(parentGroupID)})
|
||||
if err != nil {
|
||||
@@ -420,17 +718,26 @@ func (cr *channelRepository) update(ctx context.Context, ch channels.Channel, qu
|
||||
}
|
||||
|
||||
type dbChannel struct {
|
||||
ID string `db:"id"`
|
||||
Name string `db:"name,omitempty"`
|
||||
ParentGroup sql.NullString `db:"parent_group_id,omitempty"`
|
||||
Tags pgtype.TextArray `db:"tags,omitempty"`
|
||||
Domain string `db:"domain_id"`
|
||||
Metadata []byte `db:"metadata,omitempty"`
|
||||
CreatedAt time.Time `db:"created_at,omitempty"`
|
||||
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
|
||||
UpdatedBy *string `db:"updated_by,omitempty"`
|
||||
Status clients.Status `db:"status,omitempty"`
|
||||
Role *clients.Role `db:"role,omitempty"`
|
||||
ID string `db:"id"`
|
||||
Name string `db:"name,omitempty"`
|
||||
ParentGroup sql.NullString `db:"parent_group_id,omitempty"`
|
||||
Tags pgtype.TextArray `db:"tags,omitempty"`
|
||||
Domain string `db:"domain_id"`
|
||||
Metadata []byte `db:"metadata,omitempty"`
|
||||
CreatedBy *string `db:"created_by,omitempty"`
|
||||
CreatedAt time.Time `db:"created_at,omitempty"`
|
||||
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
|
||||
UpdatedBy *string `db:"updated_by,omitempty"`
|
||||
Status clients.Status `db:"status,omitempty"`
|
||||
ParentGroupPath string `db:"parent_group_path,omitempty"`
|
||||
RoleID string `db:"role_id,omitempty"`
|
||||
RoleName string `db:"role_name,omitempty"`
|
||||
Actions pq.StringArray `db:"actions,omitempty"`
|
||||
AccessType string `db:"access_type,omitempty"`
|
||||
AccessProviderId string `db:"access_provider_id,omitempty"`
|
||||
AccessProviderRoleId string `db:"access_provider_role_id,omitempty"`
|
||||
AccessProviderRoleName string `db:"access_provider_role_name,omitempty"`
|
||||
AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions,omitempty"`
|
||||
}
|
||||
|
||||
func toDBChannel(ch channels.Channel) (dbChannel, error) {
|
||||
@@ -446,6 +753,10 @@ func toDBChannel(ch channels.Channel) (dbChannel, error) {
|
||||
if err := tags.Set(ch.Tags); err != nil {
|
||||
return dbChannel{}, err
|
||||
}
|
||||
var createdBy *string
|
||||
if ch.CreatedBy != "" {
|
||||
createdBy = &ch.CreatedBy
|
||||
}
|
||||
var updatedBy *string
|
||||
if ch.UpdatedBy != "" {
|
||||
updatedBy = &ch.UpdatedBy
|
||||
@@ -461,6 +772,7 @@ func toDBChannel(ch channels.Channel) (dbChannel, error) {
|
||||
Domain: ch.Domain,
|
||||
Tags: tags,
|
||||
Metadata: data,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: ch.CreatedAt,
|
||||
UpdatedAt: updatedAt,
|
||||
UpdatedBy: updatedBy,
|
||||
@@ -497,6 +809,10 @@ func toChannel(ch dbChannel) (channels.Channel, error) {
|
||||
for _, e := range ch.Tags.Elements {
|
||||
tags = append(tags, e.String)
|
||||
}
|
||||
var createdBy string
|
||||
if ch.CreatedBy != nil {
|
||||
createdBy = *ch.CreatedBy
|
||||
}
|
||||
var updatedBy string
|
||||
if ch.UpdatedBy != nil {
|
||||
updatedBy = *ch.UpdatedBy
|
||||
@@ -507,16 +823,26 @@ func toChannel(ch dbChannel) (channels.Channel, error) {
|
||||
}
|
||||
|
||||
newCh := channels.Channel{
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Tags: tags,
|
||||
Domain: ch.Domain,
|
||||
ParentGroup: toString(ch.ParentGroup),
|
||||
Metadata: metadata,
|
||||
CreatedAt: ch.CreatedAt,
|
||||
UpdatedAt: updatedAt,
|
||||
UpdatedBy: updatedBy,
|
||||
Status: ch.Status,
|
||||
ID: ch.ID,
|
||||
Name: ch.Name,
|
||||
Tags: tags,
|
||||
Domain: ch.Domain,
|
||||
ParentGroup: toString(ch.ParentGroup),
|
||||
Metadata: metadata,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: ch.CreatedAt,
|
||||
UpdatedAt: updatedAt,
|
||||
UpdatedBy: updatedBy,
|
||||
Status: ch.Status,
|
||||
ParentGroupPath: ch.ParentGroupPath,
|
||||
RoleID: ch.RoleID,
|
||||
RoleName: ch.RoleName,
|
||||
Actions: ch.Actions,
|
||||
AccessType: ch.AccessType,
|
||||
AccessProviderId: ch.AccessProviderId,
|
||||
AccessProviderRoleId: ch.AccessProviderRoleId,
|
||||
AccessProviderRoleName: ch.AccessProviderRoleName,
|
||||
AccessProviderRoleActions: ch.AccessProviderRoleActions,
|
||||
}
|
||||
|
||||
return newCh, nil
|
||||
@@ -533,9 +859,6 @@ func PageQuery(pm channels.PageMetadata) (string, error) {
|
||||
query = append(query, "c.name ILIKE '%' || :name || '%'")
|
||||
}
|
||||
|
||||
if pm.ClientID != "" {
|
||||
query = append(query, "conn.client_id = :client_id")
|
||||
}
|
||||
if pm.Id != "" {
|
||||
query = append(query, "c.id ILIKE '%' || :id || '%'")
|
||||
}
|
||||
@@ -543,12 +866,6 @@ func PageQuery(pm channels.PageMetadata) (string, error) {
|
||||
query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')")
|
||||
}
|
||||
|
||||
// If there are search params presents, use search and ignore other options.
|
||||
// Always combine role with search params, so len(query) > 1.
|
||||
if len(query) > 1 {
|
||||
return fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")), nil
|
||||
}
|
||||
|
||||
if mq != "" {
|
||||
query = append(query, mq)
|
||||
}
|
||||
@@ -562,6 +879,31 @@ func PageQuery(pm channels.PageMetadata) (string, error) {
|
||||
if pm.Domain != "" {
|
||||
query = append(query, "c.domain_id = :domain_id")
|
||||
}
|
||||
if pm.Group != "" {
|
||||
query = append(query, "c.parent_group_path @> (SELECT path from groups where id = :group_id) ")
|
||||
}
|
||||
if pm.Client != "" {
|
||||
query = append(query, "conn.client_id = :client_id ")
|
||||
if pm.ConnectionType != "" {
|
||||
query = append(query, "conn.type = :conn_type ")
|
||||
}
|
||||
}
|
||||
if pm.AccessType != "" {
|
||||
query = append(query, "c.access_type = :access_type")
|
||||
}
|
||||
if pm.RoleID != "" {
|
||||
query = append(query, "c.role_id = :role_id")
|
||||
}
|
||||
if pm.RoleName != "" {
|
||||
query = append(query, "c.role_name = :role_name")
|
||||
}
|
||||
if len(pm.Actions) != 0 {
|
||||
query = append(query, "c.actions @> :actions")
|
||||
}
|
||||
if len(pm.Metadata) > 0 {
|
||||
query = append(query, "c.metadata @> :metadata")
|
||||
}
|
||||
|
||||
var emq string
|
||||
if len(query) > 0 {
|
||||
emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND "))
|
||||
@@ -586,28 +928,40 @@ func toDBChannelsPage(pm channels.PageMetadata) (dbChannelsPage, error) {
|
||||
return dbChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return dbChannelsPage{
|
||||
Name: pm.Name,
|
||||
Id: pm.Id,
|
||||
Metadata: data,
|
||||
Domain: pm.Domain,
|
||||
Total: pm.Total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
Status: pm.Status,
|
||||
Tag: pm.Tag,
|
||||
Limit: pm.Limit,
|
||||
Offset: pm.Offset,
|
||||
Name: pm.Name,
|
||||
Id: pm.Id,
|
||||
Domain: pm.Domain,
|
||||
Metadata: data,
|
||||
Tag: pm.Tag,
|
||||
Status: pm.Status,
|
||||
GroupID: pm.Group,
|
||||
ClientID: pm.Client,
|
||||
ConnType: pm.ConnectionType,
|
||||
RoleName: pm.RoleName,
|
||||
RoleID: pm.RoleID,
|
||||
Actions: pm.Actions,
|
||||
AccessType: pm.AccessType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type dbChannelsPage struct {
|
||||
Total uint64 `db:"total"`
|
||||
Limit uint64 `db:"limit"`
|
||||
Offset uint64 `db:"offset"`
|
||||
Name string `db:"name"`
|
||||
Id string `db:"id"`
|
||||
Domain string `db:"domain_id"`
|
||||
Metadata []byte `db:"metadata"`
|
||||
Tag string `db:"tag"`
|
||||
Status clients.Status `db:"status"`
|
||||
Limit uint64 `db:"limit"`
|
||||
Offset uint64 `db:"offset"`
|
||||
Name string `db:"name"`
|
||||
Id string `db:"id"`
|
||||
Domain string `db:"domain_id"`
|
||||
Metadata []byte `db:"metadata"`
|
||||
Tag string `db:"tag"`
|
||||
Status clients.Status `db:"status"`
|
||||
GroupID string `db:"group_id"`
|
||||
ClientID string `db:"client_id"`
|
||||
ConnType string `db:"type"`
|
||||
RoleName string `db:"role_name"`
|
||||
RoleID string `db:"role_id"`
|
||||
Actions pq.StringArray `db:"actions"`
|
||||
AccessType string `db:"access_type"`
|
||||
}
|
||||
|
||||
type dbConnection struct {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
gpostgres "github.com/absmach/supermq/groups/postgres"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
|
||||
@@ -56,5 +57,13 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
|
||||
},
|
||||
}
|
||||
channelsMigration.Migrations = append(channelsMigration.Migrations, rolesMigration.Migrations...)
|
||||
|
||||
groupsMigration, err := gpostgres.Migration()
|
||||
if err != nil {
|
||||
return &migrate.MemoryMigrationSource{}, err
|
||||
}
|
||||
|
||||
channelsMigration.Migrations = append(channelsMigration.Migrations, groupsMigration.Migrations...)
|
||||
|
||||
return channelsMigration, nil
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ const (
|
||||
// External Permission
|
||||
// Domains.
|
||||
domainCreateChannelPermission = "channel_create_permission"
|
||||
domainListChanelPermission = "list_channels_permission"
|
||||
domainListChanelPermission = "channel_read_permission"
|
||||
// Groups.
|
||||
groupSetChildChannelPermission = "channel_create_permission"
|
||||
groupRemoveChildChannelPermission = "channel_create_permission"
|
||||
|
||||
+12
-68
@@ -21,7 +21,6 @@ import (
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -183,50 +182,30 @@ func (svc service) ViewChannel(ctx context.Context, session authn.Session, id st
|
||||
}
|
||||
|
||||
func (svc service) ListChannels(ctx context.Context, session authn.Session, pm PageMetadata) (Page, error) {
|
||||
var ids []string
|
||||
var err error
|
||||
|
||||
switch session.SuperAdmin {
|
||||
case true:
|
||||
pm.Domain = session.DomainID
|
||||
default:
|
||||
ids, err = svc.listChannelIDs(ctx, session.DomainUserID, pm.Permission)
|
||||
cp, err := svc.repo.RetrieveAll(ctx, pm)
|
||||
if err != nil {
|
||||
return Page{}, errors.Wrap(svcerr.ErrNotFound, err)
|
||||
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
return cp, nil
|
||||
default:
|
||||
cp, err := svc.repo.RetrieveUserChannels(ctx, session.DomainID, session.UserID, pm)
|
||||
if err != nil {
|
||||
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
return cp, nil
|
||||
}
|
||||
if len(ids) == 0 && pm.Domain == "" {
|
||||
return Page{}, nil
|
||||
}
|
||||
pm.IDs = ids
|
||||
}
|
||||
|
||||
cp, err := svc.repo.RetrieveAll(ctx, pm)
|
||||
func (svc service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm PageMetadata) (Page, error) {
|
||||
cp, err := svc.repo.RetrieveUserChannels(ctx, session.DomainID, userID, pm)
|
||||
if err != nil {
|
||||
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
if pm.ListPerms && len(cp.Channels) > 0 {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
for i := range cp.Channels {
|
||||
// Copying loop variable "i" to avoid "loop variable captured by func literal"
|
||||
iter := i
|
||||
g.Go(func() error {
|
||||
return svc.retrievePermissions(ctx, session.DomainUserID, &cp.Channels[iter])
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return Page{}, err
|
||||
}
|
||||
}
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
func (svc service) ListChannelsByClient(ctx context.Context, session authn.Session, clID string, pm PageMetadata) (Page, error) {
|
||||
return Page{}, nil
|
||||
}
|
||||
|
||||
func (svc service) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
|
||||
ok, err := svc.repo.DoesChannelHaveConnections(ctx, id)
|
||||
if err != nil {
|
||||
@@ -493,41 +472,6 @@ func (svc service) RemoveParentGroup(ctx context.Context, session authn.Session,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) listChannelIDs(ctx context.Context, userID, permission string) ([]string, error) {
|
||||
tids, err := svc.policy.ListAllObjects(ctx, policies.Policy{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: userID,
|
||||
Permission: permission,
|
||||
ObjectType: policies.ChannelType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(svcerr.ErrNotFound, err)
|
||||
}
|
||||
return tids.Policies, nil
|
||||
}
|
||||
|
||||
func (svc service) retrievePermissions(ctx context.Context, userID string, channel *Channel) error {
|
||||
permissions, err := svc.listUserClientPermission(ctx, userID, channel.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
channel.Permissions = permissions
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) listUserClientPermission(ctx context.Context, userID, clientID string) ([]string, error) {
|
||||
lp, err := svc.policy.ListPermissions(ctx, policies.Policy{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: userID,
|
||||
Object: clientID,
|
||||
ObjectType: policies.ChannelType,
|
||||
}, []string{})
|
||||
if err != nil {
|
||||
return []string{}, errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
return lp, nil
|
||||
}
|
||||
|
||||
func (svc service) changeChannelStatus(ctx context.Context, userID string, channel Channel) (Channel, error) {
|
||||
dbchannel, err := svc.repo.RetrieveByID(ctx, channel.ID)
|
||||
if err != nil {
|
||||
|
||||
+148
-187
@@ -21,6 +21,7 @@ import (
|
||||
gpmocks "github.com/absmach/supermq/groups/mocks"
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/connections"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
@@ -459,220 +460,180 @@ func TestDisableChannel(t *testing.T) {
|
||||
func TestListChannels(t *testing.T) {
|
||||
svc := newService(t)
|
||||
|
||||
channelWithPerms := validChannel
|
||||
channelWithPerms.Permissions = []string{policysvc.AdminPermission, policysvc.EditPermission, policysvc.ViewPermission}
|
||||
adminID := testsutil.GenerateUUID(t)
|
||||
domainID := testsutil.GenerateUUID(t)
|
||||
nonAdminID := testsutil.GenerateUUID(t)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
pageMeta channels.PageMetadata
|
||||
listAllObjectsRes policysvc.PolicyPage
|
||||
listAllObjectsErr error
|
||||
retrieveAllRes channels.Page
|
||||
retrieveAllErr error
|
||||
listPermissionsRes policysvc.Permissions
|
||||
listPermissionsErr error
|
||||
resp channels.Page
|
||||
err error
|
||||
desc string
|
||||
userKind string
|
||||
session smqauthn.Session
|
||||
page channels.PageMetadata
|
||||
retrieveAllResponse channels.Page
|
||||
response channels.Page
|
||||
id string
|
||||
size uint64
|
||||
listObjectsErr error
|
||||
retrieveAllErr error
|
||||
listPermissionsErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list channesls as admin successfully",
|
||||
session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true},
|
||||
pageMeta: channels.PageMetadata{
|
||||
Domain: validID,
|
||||
desc: "list all channels successfully as non admin",
|
||||
userKind: "non-admin",
|
||||
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
|
||||
id: nonAdminID,
|
||||
page: channels.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
retrieveAllRes: channels.Page{
|
||||
Channels: []channels.Channel{validChannel},
|
||||
retrieveAllResponse: channels.Page{
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
Total: 2,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
Channels: []channels.Channel{validChannel, validChannel},
|
||||
},
|
||||
resp: channels.Page{
|
||||
Channels: []channels.Channel{validChannel},
|
||||
response: channels.Page{
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
Total: 2,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
Channels: []channels.Channel{validChannel, validChannel},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list channels as admin with list perms successfully",
|
||||
session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true},
|
||||
pageMeta: channels.PageMetadata{
|
||||
Domain: validID,
|
||||
ListPerms: true,
|
||||
desc: "list all channels as non admin with failed to retrieve all",
|
||||
userKind: "non-admin",
|
||||
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
|
||||
id: nonAdminID,
|
||||
page: channels.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
listPermissionsRes: policysvc.Permissions{
|
||||
policysvc.AdminPermission, policysvc.EditPermission, policysvc.ViewPermission,
|
||||
},
|
||||
retrieveAllRes: channels.Page{
|
||||
Channels: []channels.Channel{validChannel},
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
},
|
||||
},
|
||||
resp: channels.Page{
|
||||
Channels: []channels.Channel{channelWithPerms},
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
retrieveAllResponse: channels.Page{},
|
||||
response: channels.Page{},
|
||||
retrieveAllErr: repoerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "list channels as admin with failed to retrieve all",
|
||||
session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true},
|
||||
pageMeta: channels.PageMetadata{
|
||||
Domain: validID,
|
||||
desc: "list all channels as non admin with failed super admin",
|
||||
userKind: "non-admin",
|
||||
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
|
||||
id: nonAdminID,
|
||||
page: channels.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
retrieveAllRes: channels.Page{},
|
||||
retrieveAllErr: repoerr.ErrNotFound,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "list channels as admin with failed to list permissions",
|
||||
session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true},
|
||||
pageMeta: channels.PageMetadata{
|
||||
Domain: validID,
|
||||
ListPerms: true,
|
||||
},
|
||||
retrieveAllRes: channels.Page{
|
||||
Channels: []channels.Channel{validChannel},
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
},
|
||||
},
|
||||
listPermissionsRes: policysvc.Permissions{},
|
||||
listPermissionsErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "list channels as admin with no domain id",
|
||||
session: authn.Session{UserID: validID, SuperAdmin: true},
|
||||
pageMeta: channels.PageMetadata{},
|
||||
response: channels.Page{},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list channels as user successfully",
|
||||
session: validSession,
|
||||
pageMeta: channels.PageMetadata{
|
||||
Permission: policysvc.ViewPermission,
|
||||
IDs: []string{validChannel.ID},
|
||||
desc: "list all channels as non admin with failed to list objects",
|
||||
userKind: "non-admin",
|
||||
id: nonAdminID,
|
||||
page: channels.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
listAllObjectsRes: policysvc.PolicyPage{
|
||||
Policies: []string{validChannel.ID},
|
||||
},
|
||||
retrieveAllRes: channels.Page{
|
||||
Channels: []channels.Channel{validChannel},
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
},
|
||||
},
|
||||
resp: channels.Page{
|
||||
Channels: []channels.Channel{validChannel},
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list channels as user with failed to list all objects",
|
||||
session: validSession,
|
||||
pageMeta: channels.PageMetadata{
|
||||
Permission: policysvc.ViewPermission,
|
||||
IDs: []string{validChannel.ID},
|
||||
},
|
||||
listAllObjectsErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "list channels as user with list permissions successfully",
|
||||
session: validSession,
|
||||
pageMeta: channels.PageMetadata{
|
||||
Permission: policysvc.ViewPermission,
|
||||
IDs: []string{validChannel.ID},
|
||||
ListPerms: true,
|
||||
},
|
||||
listAllObjectsRes: policysvc.PolicyPage{
|
||||
Policies: []string{validChannel.ID},
|
||||
},
|
||||
retrieveAllRes: channels.Page{
|
||||
Channels: []channels.Channel{validChannel},
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
},
|
||||
},
|
||||
listPermissionsRes: policysvc.Permissions{
|
||||
policysvc.AdminPermission, policysvc.EditPermission, policysvc.ViewPermission,
|
||||
},
|
||||
resp: channels.Page{
|
||||
Channels: []channels.Channel{channelWithPerms},
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list channels as user with list permissions and failed to list permissions",
|
||||
session: validSession,
|
||||
pageMeta: channels.PageMetadata{
|
||||
Permission: policysvc.ViewPermission,
|
||||
IDs: []string{validChannel.ID},
|
||||
ListPerms: true,
|
||||
},
|
||||
listAllObjectsRes: policysvc.PolicyPage{
|
||||
Policies: []string{validChannel.ID},
|
||||
},
|
||||
retrieveAllRes: channels.Page{
|
||||
Channels: []channels.Channel{validChannel},
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 1,
|
||||
},
|
||||
},
|
||||
listPermissionsRes: policysvc.Permissions{},
|
||||
listPermissionsErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "list channels as user with failed to retrieve all",
|
||||
session: validSession,
|
||||
pageMeta: channels.PageMetadata{
|
||||
Permission: policysvc.ViewPermission,
|
||||
IDs: []string{validChannel.ID},
|
||||
},
|
||||
listAllObjectsRes: policysvc.PolicyPage{
|
||||
Policies: []string{validChannel.ID},
|
||||
},
|
||||
retrieveAllRes: channels.Page{},
|
||||
retrieveAllErr: repoerr.ErrNotFound,
|
||||
err: repoerr.ErrNotFound,
|
||||
response: channels.Page{},
|
||||
listObjectsErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
policyCall := policies.On("ListAllObjects", context.Background(), policysvc.Policy{
|
||||
SubjectType: policysvc.UserType,
|
||||
Subject: validID,
|
||||
Permission: policysvc.ViewPermission,
|
||||
ObjectType: policysvc.ChannelType,
|
||||
}).Return(tc.listAllObjectsRes, tc.listAllObjectsErr)
|
||||
repoCall := repo.On("RetrieveAll", context.Background(), tc.pageMeta).Return(tc.retrieveAllRes, tc.retrieveAllErr)
|
||||
policyCall1 := policies.On("ListPermissions", mock.Anything, policysvc.Policy{
|
||||
SubjectType: policysvc.UserType,
|
||||
Subject: validID,
|
||||
Object: validChannel.ID,
|
||||
ObjectType: policysvc.ChannelType,
|
||||
}, []string{}).Return(tc.listPermissionsRes, tc.listPermissionsErr)
|
||||
got, err := svc.ListChannels(context.Background(), tc.session, tc.pageMeta)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
|
||||
assert.Equal(t, tc.resp, got)
|
||||
policyCall.Unset()
|
||||
repoCall.Unset()
|
||||
policyCall1.Unset()
|
||||
})
|
||||
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
retrieveUserClientsCall := repo.On("RetrieveUserChannels", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
page, err := svc.ListChannels(context.Background(), tc.session, tc.page)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
|
||||
retrieveAllCall.Unset()
|
||||
retrieveUserClientsCall.Unset()
|
||||
}
|
||||
|
||||
cases2 := []struct {
|
||||
desc string
|
||||
userKind string
|
||||
session smqauthn.Session
|
||||
page channels.PageMetadata
|
||||
retrieveAllResponse channels.Page
|
||||
response channels.Page
|
||||
id string
|
||||
size uint64
|
||||
listObjectsErr error
|
||||
retrieveAllErr error
|
||||
listPermissionsErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list all clients as admin successfully",
|
||||
userKind: "admin",
|
||||
id: adminID,
|
||||
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
|
||||
page: channels.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Domain: domainID,
|
||||
},
|
||||
retrieveAllResponse: channels.Page{
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 2,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
Channels: []channels.Channel{validChannel, validChannel},
|
||||
},
|
||||
response: channels.Page{
|
||||
PageMetadata: channels.PageMetadata{
|
||||
Total: 2,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
Channels: []channels.Channel{validChannel, validChannel},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list all clients as admin with failed to retrieve all",
|
||||
userKind: "admin",
|
||||
id: adminID,
|
||||
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
|
||||
page: channels.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Domain: domainID,
|
||||
},
|
||||
retrieveAllResponse: channels.Page{},
|
||||
retrieveAllErr: repoerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "list all clients as admin with failed to list clients",
|
||||
userKind: "admin",
|
||||
id: adminID,
|
||||
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
|
||||
page: channels.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Domain: domainID,
|
||||
},
|
||||
retrieveAllResponse: channels.Page{},
|
||||
retrieveAllErr: repoerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases2 {
|
||||
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
page, err := svc.ListChannels(context.Background(), tc.session, tc.page)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
|
||||
retrieveAllCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -50,10 +50,10 @@ func (tm *tracingMiddleware) ListChannels(ctx context.Context, session authn.Ses
|
||||
return tm.svc.ListChannels(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_list_channels")
|
||||
func (tm *tracingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_list_user_channels")
|
||||
defer span.End()
|
||||
return tm.svc.ListChannelsByClient(ctx, session, clientID, pm)
|
||||
return tm.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
// UpdateChannel traces the "UpdateChannel" operation of the wrapped policies.Service.
|
||||
|
||||
@@ -370,32 +370,6 @@ var cmdUsers = []cobra.Command{
|
||||
logOKCmd(*cmd)
|
||||
},
|
||||
},
|
||||
{
|
||||
Use: "clients <user_id> <domain_id> <user_auth_token>",
|
||||
Short: "List clients",
|
||||
Long: "List clients of user\n" +
|
||||
"Usage:\n" +
|
||||
"\tsupermq-cli users clients <user_id> <user_auth_token>\n",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if len(args) != 3 {
|
||||
logUsageCmd(*cmd, cmd.Use)
|
||||
return
|
||||
}
|
||||
|
||||
pm := smqsdk.PageMetadata{
|
||||
Offset: Offset,
|
||||
Limit: Limit,
|
||||
}
|
||||
|
||||
tp, err := sdk.ListUserClients(args[0], args[1], pm, args[2])
|
||||
if err != nil {
|
||||
logErrorCmd(*cmd, err)
|
||||
return
|
||||
}
|
||||
|
||||
logJSONCmd(*cmd, tp)
|
||||
},
|
||||
},
|
||||
{
|
||||
Use: "search <query> <user_auth_token>",
|
||||
Short: "Search users",
|
||||
|
||||
+95
-41
@@ -27,58 +27,112 @@ func decodeViewClient(_ context.Context, r *http.Request) (interface{}, error) {
|
||||
}
|
||||
|
||||
func decodeListClients(_ context.Context, r *http.Request) (interface{}, error) {
|
||||
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefClientStatus)
|
||||
name, err := apiutil.ReadStringQuery(r, api.NameKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
m, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
n, err := apiutil.ReadStringQuery(r, api.NameKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
t, err := apiutil.ReadStringQuery(r, api.TagKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
p, err := apiutil.ReadStringQuery(r, api.PermissionKey, api.DefPermission)
|
||||
|
||||
tag, err := apiutil.ReadStringQuery(r, api.TagKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
lp, err := apiutil.ReadBoolQuery(r, api.ListPerms, api.DefListPerms)
|
||||
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefGroupStatus)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
st, err := clients.ToStatus(s)
|
||||
status, err := clients.ToStatus(s)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
meta, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir)
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder)
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
allActions, err := apiutil.ReadStringQuery(r, api.ActionsKey, "")
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
actions := []string{}
|
||||
|
||||
allActions = strings.TrimSpace(allActions)
|
||||
if allActions != "" {
|
||||
actions = strings.Split(allActions, ",")
|
||||
}
|
||||
roleID, err := apiutil.ReadStringQuery(r, api.RoleIDKey, "")
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
roleName, err := apiutil.ReadStringQuery(r, api.RoleNameKey, "")
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
accessType, err := apiutil.ReadStringQuery(r, api.AccessTypeKey, "")
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
userID, err := apiutil.ReadStringQuery(r, api.UserKey, "")
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
groupID, err := apiutil.ReadStringQuery(r, api.GroupKey, "")
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
channelID, err := apiutil.ReadStringQuery(r, api.ChannelKey, "")
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
connType, err := apiutil.ReadStringQuery(r, api.ConnTypeKey, "")
|
||||
if err != nil {
|
||||
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
req := listClientsReq{
|
||||
status: st,
|
||||
offset: o,
|
||||
limit: l,
|
||||
metadata: m,
|
||||
name: n,
|
||||
tag: t,
|
||||
permission: p,
|
||||
listPerms: lp,
|
||||
userID: chi.URLParam(r, "userID"),
|
||||
id: id,
|
||||
name: name,
|
||||
tag: tag,
|
||||
status: status,
|
||||
metadata: meta,
|
||||
roleName: roleName,
|
||||
roleID: roleID,
|
||||
actions: actions,
|
||||
accessType: accessType,
|
||||
order: order,
|
||||
dir: dir,
|
||||
offset: offset,
|
||||
limit: limit,
|
||||
groupID: groupID,
|
||||
channelID: channelID,
|
||||
connType: connType,
|
||||
userID: userID,
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -104,19 +104,33 @@ func listClientsEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
}
|
||||
|
||||
pm := clients.Page{
|
||||
Status: req.status,
|
||||
Offset: req.offset,
|
||||
Limit: req.limit,
|
||||
Name: req.name,
|
||||
Tag: req.tag,
|
||||
Permission: req.permission,
|
||||
Metadata: req.metadata,
|
||||
ListPerms: req.listPerms,
|
||||
Id: req.id,
|
||||
Name: req.name,
|
||||
Tag: req.tag,
|
||||
Status: req.status,
|
||||
Metadata: req.metadata,
|
||||
RoleName: req.roleName,
|
||||
RoleID: req.roleID,
|
||||
Actions: req.actions,
|
||||
AccessType: req.accessType,
|
||||
Order: req.order,
|
||||
Dir: req.dir,
|
||||
Offset: req.offset,
|
||||
Limit: req.limit,
|
||||
Group: req.groupID,
|
||||
Channel: req.channelID,
|
||||
ConnectionType: req.connType,
|
||||
}
|
||||
|
||||
var page clients.ClientsPage
|
||||
var err error
|
||||
switch req.userID != "" {
|
||||
case true:
|
||||
page, err = svc.ListUserClients(ctx, session, req.userID, pm)
|
||||
default:
|
||||
page, err = svc.ListClients(ctx, session, pm)
|
||||
}
|
||||
page, err := svc.ListClients(ctx, session, req.userID, pm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return clientsPageRes{}, err
|
||||
}
|
||||
|
||||
res := clientsPageRes{
|
||||
|
||||
@@ -743,7 +743,7 @@ func TestListClients(t *testing.T) {
|
||||
}
|
||||
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("ListClients", mock.Anything, tc.authnRes, "", mock.Anything).Return(tc.listClientsResponse, tc.err)
|
||||
svcCall := svc.On("ListClients", mock.Anything, tc.authnRes, mock.Anything).Return(tc.listClientsResponse, tc.err)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
|
||||
|
||||
@@ -71,29 +71,29 @@ func (req viewClientPermsReq) validate() error {
|
||||
}
|
||||
|
||||
type listClientsReq struct {
|
||||
status clients.Status
|
||||
offset uint64
|
||||
limit uint64
|
||||
name string
|
||||
tag string
|
||||
permission string
|
||||
visibility string
|
||||
userID string
|
||||
listPerms bool
|
||||
status clients.Status
|
||||
metadata clients.Metadata
|
||||
id string
|
||||
roleName string
|
||||
roleID string
|
||||
actions []string
|
||||
accessType string
|
||||
order string
|
||||
dir string
|
||||
offset uint64
|
||||
limit uint64
|
||||
groupID string
|
||||
channelID string
|
||||
connType string
|
||||
userID string
|
||||
}
|
||||
|
||||
func (req listClientsReq) validate() error {
|
||||
if req.limit > api.MaxLimitSize || req.limit < 1 {
|
||||
return apiutil.ErrLimitSize
|
||||
}
|
||||
if req.visibility != "" &&
|
||||
req.visibility != api.AllVisibility &&
|
||||
req.visibility != api.MyVisibility &&
|
||||
req.visibility != api.SharedVisibility {
|
||||
return apiutil.ErrInvalidVisibilityType
|
||||
}
|
||||
|
||||
if len(req.name) > api.MaxNameSize {
|
||||
return apiutil.ErrNameSize
|
||||
}
|
||||
|
||||
@@ -210,14 +210,6 @@ func TestListClientsReqValidate(t *testing.T) {
|
||||
},
|
||||
err: apiutil.ErrLimitSize,
|
||||
},
|
||||
{
|
||||
desc: "invalid visibility",
|
||||
req: listClientsReq{
|
||||
limit: 10,
|
||||
visibility: "invalid",
|
||||
},
|
||||
err: apiutil.ErrInvalidVisibilityType,
|
||||
},
|
||||
{
|
||||
desc: "name too long",
|
||||
req: listClientsReq{
|
||||
|
||||
+44
-22
@@ -13,6 +13,10 @@ import (
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
)
|
||||
|
||||
type CtxKey int
|
||||
|
||||
const ListDomainClients CtxKey = iota
|
||||
|
||||
type Connection struct {
|
||||
ClientID string
|
||||
ChannelID string
|
||||
@@ -35,11 +39,14 @@ type Repository interface {
|
||||
// RetrieveAll retrieves all clients.
|
||||
RetrieveAll(ctx context.Context, pm Page) (ClientsPage, error)
|
||||
|
||||
// RetrieveUserClients retrieve all clients of a given user id.
|
||||
RetrieveUserClients(ctx context.Context, domainID, userID string, pm Page) (ClientsPage, error)
|
||||
|
||||
// SearchClients retrieves clients based on search criteria.
|
||||
SearchClients(ctx context.Context, pm Page) (ClientsPage, error)
|
||||
|
||||
// RetrieveAllByIDs retrieves for given client IDs .
|
||||
RetrieveAllByIDs(ctx context.Context, pm Page) (ClientsPage, error)
|
||||
// RetrieveByIds
|
||||
RetrieveByIds(ctx context.Context, ids []string) (ClientsPage, error)
|
||||
|
||||
// Update updates the client name and metadata.
|
||||
Update(ctx context.Context, client Client) (Client, error)
|
||||
@@ -66,8 +73,6 @@ type Repository interface {
|
||||
// RetrieveBySecret retrieves a client based on the secret (key).
|
||||
RetrieveBySecret(ctx context.Context, key string) (Client, error)
|
||||
|
||||
RetrieveByIds(ctx context.Context, ids []string) (ClientsPage, error)
|
||||
|
||||
AddConnections(ctx context.Context, conns []Connection) error
|
||||
|
||||
RemoveConnections(ctx context.Context, conns []Connection) error
|
||||
@@ -105,8 +110,11 @@ type Service interface {
|
||||
// View retrieves client info for a given client ID and an authorized token.
|
||||
View(ctx context.Context, session authn.Session, id string) (Client, error)
|
||||
|
||||
// ListClients retrieves clients list for a valid auth token.
|
||||
ListClients(ctx context.Context, session authn.Session, reqUserID string, pm Page) (ClientsPage, error)
|
||||
// ListClients retrieves clients list for given page query.
|
||||
ListClients(ctx context.Context, session authn.Session, pm Page) (ClientsPage, error)
|
||||
|
||||
// ListUserClients retrieves clients list for a given user id and page query.
|
||||
ListUserClients(ctx context.Context, session authn.Session, userID string, pm Page) (ClientsPage, error)
|
||||
|
||||
// Update updates the client's name and metadata.
|
||||
Update(ctx context.Context, session authn.Session, client Client) (Client, error)
|
||||
@@ -161,8 +169,17 @@ type Client struct {
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
UpdatedBy string `json:"updated_by,omitempty"`
|
||||
Status Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
|
||||
Permissions []string `json:"permissions,omitempty"`
|
||||
Identity string `json:"identity,omitempty"`
|
||||
// Extended
|
||||
ParentGroupPath string `json:"parent_group_path,omitempty"`
|
||||
RoleID string `json:"role_id,omitempty"`
|
||||
RoleName string `json:"role_name,omitempty"`
|
||||
Actions []string `json:"actions,omitempty"`
|
||||
AccessType string `json:"access_type,omitempty"`
|
||||
AccessProviderId string `json:"access_provider_id,omitempty"`
|
||||
AccessProviderRoleId string `json:"access_provider_role_id,omitempty"`
|
||||
AccessProviderRoleName string `json:"access_provider_role_name,omitempty"`
|
||||
AccessProviderRoleActions []string `json:"access_provider_role_actions,omitempty"`
|
||||
}
|
||||
|
||||
// ClientsPage contains page related metadata as well as list.
|
||||
@@ -182,21 +199,26 @@ type MembersPage struct {
|
||||
// Page contains the page metadata that helps navigation.
|
||||
|
||||
type Page struct {
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Order string `json:"order,omitempty"`
|
||||
Dir string `json:"dir,omitempty"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Permission string `json:"permission,omitempty"`
|
||||
Status Status `json:"status,omitempty"`
|
||||
IDs []string `json:"ids,omitempty"`
|
||||
Identity string `json:"identity,omitempty"`
|
||||
ListPerms bool `json:"-"`
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Order string `json:"order,omitempty"`
|
||||
Dir string `json:"dir,omitempty"`
|
||||
Id string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Status Status `json:"status,omitempty"`
|
||||
Identity string `json:"identity,omitempty"`
|
||||
Group string `json:"group,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ConnectionType string `json:"connection_type,omitempty"`
|
||||
RoleName string `json:"role_name,omitempty"`
|
||||
RoleID string `json:"role_id,omitempty"`
|
||||
Actions []string `json:"actions,omitempty"`
|
||||
AccessType string `json:"access_type,omitempty"`
|
||||
IDs []string `json:"-"`
|
||||
}
|
||||
|
||||
// Metadata represents arbitrary JSON.
|
||||
|
||||
@@ -205,7 +205,6 @@ func (vcpe viewClientPermsEvent) Encode() (map[string]interface{}, error) {
|
||||
}
|
||||
|
||||
type listClientEvent struct {
|
||||
reqUserID string
|
||||
clients.Page
|
||||
authn.Session
|
||||
}
|
||||
@@ -213,7 +212,6 @@ type listClientEvent struct {
|
||||
func (lce listClientEvent) Encode() (map[string]interface{}, error) {
|
||||
val := map[string]interface{}{
|
||||
"operation": clientList,
|
||||
"reqUserID": lce.reqUserID,
|
||||
"total": lce.Total,
|
||||
"offset": lce.Offset,
|
||||
"limit": lce.Limit,
|
||||
@@ -238,8 +236,51 @@ func (lce listClientEvent) Encode() (map[string]interface{}, error) {
|
||||
if lce.Tag != "" {
|
||||
val["tag"] = lce.Tag
|
||||
}
|
||||
if lce.Permission != "" {
|
||||
val["permission"] = lce.Permission
|
||||
if lce.Status.String() != "" {
|
||||
val["status"] = lce.Status.String()
|
||||
}
|
||||
if len(lce.IDs) > 0 {
|
||||
val["ids"] = lce.IDs
|
||||
}
|
||||
if lce.Identity != "" {
|
||||
val["identity"] = lce.Identity
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type listUserClientEvent struct {
|
||||
userID string
|
||||
clients.Page
|
||||
authn.Session
|
||||
}
|
||||
|
||||
func (lce listUserClientEvent) Encode() (map[string]interface{}, error) {
|
||||
val := map[string]interface{}{
|
||||
"operation": clientList,
|
||||
"req_user_id": lce.userID,
|
||||
"total": lce.Total,
|
||||
"offset": lce.Offset,
|
||||
"limit": lce.Limit,
|
||||
"domain": lce.DomainID,
|
||||
"user_id": lce.UserID,
|
||||
"token_type": lce.Type.String(),
|
||||
"super_admin": lce.SuperAdmin,
|
||||
}
|
||||
|
||||
if lce.Name != "" {
|
||||
val["name"] = lce.Name
|
||||
}
|
||||
if lce.Order != "" {
|
||||
val["order"] = lce.Order
|
||||
}
|
||||
if lce.Dir != "" {
|
||||
val["dir"] = lce.Dir
|
||||
}
|
||||
if lce.Metadata != nil {
|
||||
val["metadata"] = lce.Metadata
|
||||
}
|
||||
if lce.Tag != "" {
|
||||
val["tag"] = lce.Tag
|
||||
}
|
||||
if lce.Status.String() != "" {
|
||||
val["status"] = lce.Status.String()
|
||||
@@ -288,9 +329,6 @@ func (lcge listClientByGroupEvent) Encode() (map[string]interface{}, error) {
|
||||
if lcge.Tag != "" {
|
||||
val["tag"] = lcge.Tag
|
||||
}
|
||||
if lcge.Permission != "" {
|
||||
val["permission"] = lcge.Permission
|
||||
}
|
||||
if lcge.Status.String() != "" {
|
||||
val["status"] = lcge.Status.String()
|
||||
}
|
||||
|
||||
@@ -118,15 +118,31 @@ func (es *eventStore) View(ctx context.Context, session authn.Session, id string
|
||||
return cli, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
cp, err := es.svc.ListClients(ctx, session, reqUserID, pm)
|
||||
func (es *eventStore) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
|
||||
cp, err := es.svc.ListClients(ctx, session, pm)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
event := listClientEvent{
|
||||
reqUserID: reqUserID,
|
||||
Page: pm,
|
||||
Session: session,
|
||||
pm,
|
||||
session,
|
||||
}
|
||||
if err := es.Publish(ctx, event); err != nil {
|
||||
return cp, err
|
||||
}
|
||||
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
cp, err := es.svc.ListUserClients(ctx, session, userID, pm)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
event := listUserClientEvent{
|
||||
userID,
|
||||
pm,
|
||||
session,
|
||||
}
|
||||
if err := es.Publish(ctx, event); err != nil {
|
||||
return cp, err
|
||||
|
||||
@@ -129,7 +129,29 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi
|
||||
return am.svc.View(ctx, session, id)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
func (am *authorizationMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
|
||||
if session.Type == authn.PersonalAccessToken {
|
||||
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
|
||||
UserID: session.UserID,
|
||||
PatID: session.PatID,
|
||||
PlatformEntityType: auth.PlatformDomainsScope,
|
||||
OptionalDomainID: session.DomainID,
|
||||
OptionalDomainEntityType: auth.DomainClientsScope,
|
||||
Operation: auth.ListOp,
|
||||
EntityIDs: auth.AnyIDs{}.Values(),
|
||||
}); err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := am.checkSuperAdmin(ctx, session.UserID); err == nil {
|
||||
session.SuperAdmin = true
|
||||
}
|
||||
|
||||
return am.svc.ListClients(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
if session.Type == authn.PersonalAccessToken {
|
||||
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
|
||||
UserID: session.UserID,
|
||||
@@ -145,10 +167,10 @@ func (am *authorizationMiddleware) ListClients(ctx context.Context, session auth
|
||||
}
|
||||
|
||||
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
|
||||
session.SuperAdmin = true
|
||||
return clients.ClientsPage{}, err
|
||||
}
|
||||
|
||||
return am.svc.ListClients(ctx, session, reqUserID, pm)
|
||||
return am.svc.ListUserClients(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
|
||||
|
||||
@@ -65,11 +65,10 @@ func (lm *loggingMiddleware) View(ctx context.Context, session authn.Session, id
|
||||
return lm.svc.View(ctx, session, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (cp clients.ClientsPage, err error) {
|
||||
func (lm *loggingMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (cp clients.ClientsPage, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("user_id", reqUserID),
|
||||
slog.Group("page",
|
||||
slog.Uint64("limit", pm.Limit),
|
||||
slog.Uint64("offset", pm.Offset),
|
||||
@@ -83,7 +82,28 @@ func (lm *loggingMiddleware) ListClients(ctx context.Context, session authn.Sess
|
||||
}
|
||||
lm.logger.Info("List clients completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.ListClients(ctx, session, reqUserID, pm)
|
||||
return lm.svc.ListClients(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (cp clients.ClientsPage, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("user_id", userID),
|
||||
slog.Group("page",
|
||||
slog.Uint64("limit", pm.Limit),
|
||||
slog.Uint64("offset", pm.Offset),
|
||||
slog.Uint64("total", cp.Total),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("List clients failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("List clients completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.ListUserClients(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (c clients.Client, err error) {
|
||||
|
||||
@@ -49,12 +49,20 @@ func (ms *metricsMiddleware) View(ctx context.Context, session authn.Session, id
|
||||
return ms.svc.View(ctx, session, id)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
func (ms *metricsMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "list_clients").Add(1)
|
||||
ms.latency.With("method", "list_clients").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.ListClients(ctx, session, reqUserID, pm)
|
||||
return ms.svc.ListClients(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "list_user_clients").Add(1)
|
||||
ms.latency.With("method", "list_user_clients").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.ListUserClients(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
|
||||
|
||||
+28
-28
@@ -312,34 +312,6 @@ func (_m *Repository) RetrieveAll(ctx context.Context, pm clients.Page) (clients
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RetrieveAllByIDs provides a mock function with given fields: ctx, pm
|
||||
func (_m *Repository) RetrieveAllByIDs(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
|
||||
ret := _m.Called(ctx, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveAllByIDs")
|
||||
}
|
||||
|
||||
var r0 clients.ClientsPage
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, clients.Page) (clients.ClientsPage, error)); ok {
|
||||
return rf(ctx, pm)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, clients.Page) clients.ClientsPage); ok {
|
||||
r0 = rf(ctx, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(clients.ClientsPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, clients.Page) error); ok {
|
||||
r1 = rf(ctx, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RetrieveAllRoles provides a mock function with given fields: ctx, entityID, limit, offset
|
||||
func (_m *Repository) RetrieveAllRoles(ctx context.Context, entityID string, limit uint64, offset uint64) (roles.RolePage, error) {
|
||||
ret := _m.Called(ctx, entityID, limit, offset)
|
||||
@@ -577,6 +549,34 @@ func (_m *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Ro
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RetrieveUserClients provides a mock function with given fields: ctx, domainID, userID, pm
|
||||
func (_m *Repository) RetrieveUserClients(ctx context.Context, domainID string, userID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
ret := _m.Called(ctx, domainID, userID, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveUserClients")
|
||||
}
|
||||
|
||||
var r0 clients.ClientsPage
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, clients.Page) (clients.ClientsPage, error)); ok {
|
||||
return rf(ctx, domainID, userID, pm)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, clients.Page) clients.ClientsPage); ok {
|
||||
r0 = rf(ctx, domainID, userID, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(clients.ClientsPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, clients.Page) error); ok {
|
||||
r1 = rf(ctx, domainID, userID, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RoleAddActions provides a mock function with given fields: ctx, role, actions
|
||||
func (_m *Repository) RoleAddActions(ctx context.Context, role roles.Role, actions []string) ([]string, error) {
|
||||
ret := _m.Called(ctx, role, actions)
|
||||
|
||||
@@ -198,27 +198,55 @@ func (_m *Service) ListAvailableActions(ctx context.Context, session authn.Sessi
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListClients provides a mock function with given fields: ctx, session, reqUserID, pm
|
||||
func (_m *Service) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
ret := _m.Called(ctx, session, reqUserID, pm)
|
||||
// ListClients provides a mock function with given fields: ctx, session, pm
|
||||
func (_m *Service) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
|
||||
ret := _m.Called(ctx, session, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListClients")
|
||||
}
|
||||
|
||||
var r0 clients.ClientsPage
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, clients.Page) (clients.ClientsPage, error)); ok {
|
||||
return rf(ctx, session, pm)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, clients.Page) clients.ClientsPage); ok {
|
||||
r0 = rf(ctx, session, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(clients.ClientsPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, clients.Page) error); ok {
|
||||
r1 = rf(ctx, session, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListUserClients provides a mock function with given fields: ctx, session, userID, pm
|
||||
func (_m *Service) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
ret := _m.Called(ctx, session, userID, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListUserClients")
|
||||
}
|
||||
|
||||
var r0 clients.ClientsPage
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, clients.Page) (clients.ClientsPage, error)); ok {
|
||||
return rf(ctx, session, reqUserID, pm)
|
||||
return rf(ctx, session, userID, pm)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, clients.Page) clients.ClientsPage); ok {
|
||||
r0 = rf(ctx, session, reqUserID, pm)
|
||||
r0 = rf(ctx, session, userID, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(clients.ClientsPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, clients.Page) error); ok {
|
||||
r1 = rf(ctx, session, reqUserID, pm)
|
||||
r1 = rf(ctx, session, userID, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
+441
-105
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/absmach/supermq/pkg/postgres"
|
||||
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -247,6 +248,350 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (repo *clientRepo) RetrieveUserClients(ctx context.Context, domainID, userID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
return repo.retrieveClients(ctx, domainID, userID, pm)
|
||||
}
|
||||
|
||||
func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
pageQuery, err := PageQuery(pm)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, err
|
||||
}
|
||||
|
||||
bq := repo.userClientBaseQuery(domainID, userID)
|
||||
|
||||
q := fmt.Sprintf(`
|
||||
%s
|
||||
SELECT
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.identity,
|
||||
c.secret,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
c.parent_group_path,
|
||||
c.role_id,
|
||||
c.role_name,
|
||||
c.actions,
|
||||
c.access_type,
|
||||
c.access_provider_id,
|
||||
c.access_provider_role_id,
|
||||
c.access_provider_role_name,
|
||||
c.access_provider_role_actions
|
||||
FROM
|
||||
final_clients c
|
||||
%s
|
||||
`, bq, pageQuery)
|
||||
|
||||
q = applyOrdering(q, pm)
|
||||
|
||||
dbPage, err := ToDBClientsPage(pm)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []clients.Client
|
||||
for rows.Next() {
|
||||
dbc := DBClient{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToClient(dbc)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, err
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
}
|
||||
|
||||
connJoinQuery := ""
|
||||
if pm.Channel != "" {
|
||||
connJoinQuery = "JOIN connection conn ON conn.client_id = c.id"
|
||||
}
|
||||
cq := fmt.Sprintf(`%s
|
||||
SELECT COUNT(*) AS total_count
|
||||
FROM (
|
||||
SELECT
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.identity,
|
||||
c.secret,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
c.parent_group_path,
|
||||
c.role_id,
|
||||
c.role_name,
|
||||
c.actions,
|
||||
c.access_type,
|
||||
c.access_provider_id,
|
||||
c.access_provider_role_id,
|
||||
c.access_provider_role_name,
|
||||
c.access_provider_role_actions
|
||||
FROM
|
||||
final_clients c
|
||||
%s
|
||||
%s
|
||||
) AS subquery;
|
||||
`, bq, connJoinQuery, pageQuery)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := clients.ClientsPage{
|
||||
Clients: items,
|
||||
Page: clients.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}
|
||||
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string {
|
||||
return fmt.Sprintf(`
|
||||
WITH direct_clients AS (
|
||||
SELECT
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.identity,
|
||||
c.secret,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
text2ltree('') as parent_group_path,
|
||||
cr.id AS role_id,
|
||||
cr."name" AS role_name,
|
||||
array_agg(cra."action") AS actions,
|
||||
'direct' as access_type,
|
||||
'' AS access_provider_id,
|
||||
'' AS access_provider_role_id,
|
||||
'' AS access_provider_role_name,
|
||||
array[]::::text[] AS access_provider_role_actions
|
||||
FROM
|
||||
clients_role_members crm
|
||||
JOIN
|
||||
clients_role_actions cra ON cra.role_id = crm.role_id
|
||||
JOIN
|
||||
clients_roles cr ON cr.id = crm.role_id
|
||||
JOIN
|
||||
clients c ON c.id = cr.entity_id
|
||||
WHERE
|
||||
crm.member_id = '%s'
|
||||
AND c.domain_id = '%s'
|
||||
GROUP BY
|
||||
cr.entity_id, crm.member_id, cr.id, cr."name", c.id
|
||||
),
|
||||
direct_groups AS (
|
||||
SELECT
|
||||
g.*,
|
||||
gr.entity_id AS entity_id,
|
||||
grm.member_id AS member_id,
|
||||
gr.id AS role_id,
|
||||
gr."name" AS role_name,
|
||||
array_agg(gra."action") AS actions
|
||||
FROM
|
||||
groups_role_members grm
|
||||
JOIN
|
||||
groups_role_actions gra ON gra.role_id = grm.role_id
|
||||
JOIN
|
||||
groups_roles gr ON gr.id = grm.role_id
|
||||
JOIN
|
||||
"groups" g ON g.id = gr.entity_id
|
||||
WHERE
|
||||
grm.member_id = '%s'
|
||||
AND g.domain_id = '%s'
|
||||
GROUP BY
|
||||
gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id
|
||||
),
|
||||
direct_groups_with_subgroup AS (
|
||||
SELECT
|
||||
*
|
||||
FROM direct_groups
|
||||
WHERE EXISTS (
|
||||
SELECT 1
|
||||
FROM unnest(direct_groups.actions) AS action
|
||||
WHERE action LIKE 'subgroup_%%'
|
||||
)
|
||||
),
|
||||
indirect_child_groups AS (
|
||||
SELECT
|
||||
DISTINCT indirect_child_groups.id as child_id,
|
||||
indirect_child_groups.*,
|
||||
dgws.id as access_provider_id,
|
||||
dgws.role_id as access_provider_role_id,
|
||||
dgws.role_name as access_provider_role_name,
|
||||
dgws.actions as access_provider_role_actions
|
||||
FROM
|
||||
direct_groups_with_subgroup dgws
|
||||
JOIN
|
||||
groups indirect_child_groups ON indirect_child_groups.path <@ dgws.path
|
||||
WHERE
|
||||
indirect_child_groups.domain_id = '%s'
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM direct_groups_with_subgroup dgws
|
||||
WHERE dgws.id = indirect_child_groups.id
|
||||
)
|
||||
),
|
||||
final_groups AS (
|
||||
SELECT
|
||||
id,
|
||||
parent_id,
|
||||
domain_id,
|
||||
"name",
|
||||
description,
|
||||
metadata,
|
||||
created_at,
|
||||
updated_at,
|
||||
updated_by,
|
||||
status,
|
||||
"path",
|
||||
role_id,
|
||||
role_name,
|
||||
actions,
|
||||
'direct_group' AS access_type,
|
||||
'' AS access_provider_id,
|
||||
'' AS access_provider_role_id,
|
||||
'' AS access_provider_role_name,
|
||||
array[]::::text[] AS access_provider_role_actions
|
||||
FROM
|
||||
direct_groups
|
||||
UNION
|
||||
SELECT
|
||||
id,
|
||||
parent_id,
|
||||
domain_id,
|
||||
"name",
|
||||
description,
|
||||
metadata,
|
||||
created_at,
|
||||
updated_at,
|
||||
updated_by,
|
||||
status,
|
||||
"path",
|
||||
'' AS role_id,
|
||||
'' AS role_name,
|
||||
array[]::::text[] AS actions,
|
||||
'indirect_group' AS access_type,
|
||||
access_provider_id,
|
||||
access_provider_role_id,
|
||||
access_provider_role_name,
|
||||
access_provider_role_actions
|
||||
FROM
|
||||
indirect_child_groups
|
||||
),
|
||||
group_direct_clients AS (
|
||||
SELECT
|
||||
c.id,
|
||||
c.name,
|
||||
c.domain_id,
|
||||
c.parent_group_id,
|
||||
c.identity,
|
||||
c.secret,
|
||||
c.tags,
|
||||
c.metadata,
|
||||
c.created_at,
|
||||
c.updated_at,
|
||||
c.updated_by,
|
||||
c.status,
|
||||
g.path AS parent_group_path,
|
||||
g.role_id,
|
||||
g.role_name,
|
||||
g.actions,
|
||||
g.access_type,
|
||||
g.access_provider_id,
|
||||
g.access_provider_role_id,
|
||||
g.access_provider_role_name,
|
||||
g.access_provider_role_actions
|
||||
FROM
|
||||
final_groups g
|
||||
JOIN
|
||||
clients c ON c.parent_group_id = g.id
|
||||
WHERE
|
||||
c.id NOT IN (SELECT id FROM direct_clients)
|
||||
UNION
|
||||
SELECT
|
||||
dc.id,
|
||||
dc.name,
|
||||
dc.domain_id,
|
||||
dc.parent_group_id,
|
||||
dc.identity,
|
||||
dc.secret,
|
||||
dc.tags,
|
||||
dc.metadata,
|
||||
dc.created_at,
|
||||
dc.updated_at,
|
||||
dc.updated_by,
|
||||
dc.status,
|
||||
dc.parent_group_path,
|
||||
dc.role_id,
|
||||
dc.role_name,
|
||||
dc.actions,
|
||||
dc.access_type,
|
||||
dc.access_provider_id,
|
||||
dc.access_provider_role_id,
|
||||
dc.access_provider_role_name,
|
||||
dc.access_provider_role_actions
|
||||
FROM
|
||||
direct_clients AS dc
|
||||
),
|
||||
final_clients AS (
|
||||
SELECT
|
||||
gdc.id,
|
||||
gdc.name,
|
||||
gdc.domain_id,
|
||||
gdc.parent_group_id,
|
||||
gdc.identity,
|
||||
gdc.secret,
|
||||
gdc.tags,
|
||||
gdc.metadata,
|
||||
gdc.created_at,
|
||||
gdc.updated_at,
|
||||
gdc.updated_by,
|
||||
gdc.status,
|
||||
gdc.parent_group_path,
|
||||
gdc.role_id,
|
||||
gdc.role_name,
|
||||
gdc.actions,
|
||||
gdc.access_type,
|
||||
gdc.access_provider_id,
|
||||
gdc.access_provider_role_id,
|
||||
gdc.access_provider_role_name,
|
||||
gdc.access_provider_role_actions
|
||||
FROM
|
||||
group_direct_clients AS gdc
|
||||
)
|
||||
`, userID, domainID, userID, domainID, domainID)
|
||||
}
|
||||
|
||||
func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
|
||||
query, err := PageQuery(pm)
|
||||
if err != nil {
|
||||
@@ -302,64 +647,6 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (repo *clientRepo) RetrieveAllByIDs(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
|
||||
if (len(pm.IDs) == 0) && (pm.Domain == "") {
|
||||
return clients.ClientsPage{
|
||||
Page: clients.Page{Total: pm.Total, Offset: pm.Offset, Limit: pm.Limit},
|
||||
}, nil
|
||||
}
|
||||
query, err := PageQuery(pm)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
query = applyOrdering(query, pm)
|
||||
|
||||
q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, c.status,
|
||||
c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query)
|
||||
|
||||
dbPage, err := ToDBClientsPage(pm)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var items []clients.Client
|
||||
for rows.Next() {
|
||||
dbc := DBClient{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToClient(dbc)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, err
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
}
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM clients c %s;`, query)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := clients.ClientsPage{
|
||||
Clients: items,
|
||||
Page: clients.Page{
|
||||
Total: total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
},
|
||||
}
|
||||
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (repo *clientRepo) update(ctx context.Context, client clients.Client, query string) (clients.Client, error) {
|
||||
dbc, err := ToDBClient(client)
|
||||
if err != nil {
|
||||
@@ -402,18 +689,27 @@ func (repo *clientRepo) Delete(ctx context.Context, clientIDs ...string) error {
|
||||
}
|
||||
|
||||
type DBClient struct {
|
||||
ID string `db:"id"`
|
||||
Name string `db:"name,omitempty"`
|
||||
Tags pgtype.TextArray `db:"tags,omitempty"`
|
||||
Identity string `db:"identity"`
|
||||
Domain string `db:"domain_id"`
|
||||
ParentGroup sql.NullString `db:"parent_group_id,omitempty"`
|
||||
Secret string `db:"secret"`
|
||||
Metadata []byte `db:"metadata,omitempty"`
|
||||
CreatedAt time.Time `db:"created_at,omitempty"`
|
||||
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
|
||||
UpdatedBy *string `db:"updated_by,omitempty"`
|
||||
Status clients.Status `db:"status,omitempty"`
|
||||
ID string `db:"id"`
|
||||
Name string `db:"name,omitempty"`
|
||||
Tags pgtype.TextArray `db:"tags,omitempty"`
|
||||
Identity string `db:"identity"`
|
||||
Domain string `db:"domain_id"`
|
||||
ParentGroup sql.NullString `db:"parent_group_id,omitempty"`
|
||||
Secret string `db:"secret"`
|
||||
Metadata []byte `db:"metadata,omitempty"`
|
||||
CreatedAt time.Time `db:"created_at,omitempty"`
|
||||
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
|
||||
UpdatedBy *string `db:"updated_by,omitempty"`
|
||||
Status clients.Status `db:"status,omitempty"`
|
||||
ParentGroupPath string `db:"parent_group_path,omitempty"`
|
||||
RoleID string `db:"role_id,omitempty"`
|
||||
RoleName string `db:"role_name,omitempty"`
|
||||
Actions pq.StringArray `db:"actions,omitempty"`
|
||||
AccessType string `db:"access_type,omitempty"`
|
||||
AccessProviderId string `db:"access_provider_id,omitempty"`
|
||||
AccessProviderRoleId string `db:"access_provider_role_id,omitempty"`
|
||||
AccessProviderRoleName string `db:"access_provider_role_name,omitempty"`
|
||||
AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions,omitempty"`
|
||||
}
|
||||
|
||||
func ToDBClient(c clients.Client) (DBClient, error) {
|
||||
@@ -484,11 +780,19 @@ func ToClient(t DBClient) (clients.Client, error) {
|
||||
Identity: t.Identity,
|
||||
Secret: t.Secret,
|
||||
},
|
||||
Metadata: metadata,
|
||||
CreatedAt: t.CreatedAt,
|
||||
UpdatedAt: updatedAt,
|
||||
UpdatedBy: updatedBy,
|
||||
Status: t.Status,
|
||||
Metadata: metadata,
|
||||
CreatedAt: t.CreatedAt,
|
||||
UpdatedAt: updatedAt,
|
||||
UpdatedBy: updatedBy,
|
||||
Status: t.Status,
|
||||
RoleID: t.RoleID,
|
||||
RoleName: t.RoleName,
|
||||
Actions: t.Actions,
|
||||
AccessType: t.AccessType,
|
||||
AccessProviderId: t.AccessProviderId,
|
||||
AccessProviderRoleId: t.AccessProviderRoleId,
|
||||
AccessProviderRoleName: t.AccessProviderRoleName,
|
||||
AccessProviderRoleActions: t.AccessProviderRoleActions,
|
||||
}
|
||||
return cli, nil
|
||||
}
|
||||
@@ -499,31 +803,42 @@ func ToDBClientsPage(pm clients.Page) (dbClientsPage, error) {
|
||||
return dbClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return dbClientsPage{
|
||||
Name: pm.Name,
|
||||
Identity: pm.Identity,
|
||||
Id: pm.Id,
|
||||
Metadata: data,
|
||||
Domain: pm.Domain,
|
||||
Total: pm.Total,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
Status: pm.Status,
|
||||
Tag: pm.Tag,
|
||||
Offset: pm.Offset,
|
||||
Limit: pm.Limit,
|
||||
Name: pm.Name,
|
||||
Identity: pm.Identity,
|
||||
Id: pm.Id,
|
||||
Metadata: data,
|
||||
Domain: pm.Domain,
|
||||
Status: pm.Status,
|
||||
Tag: pm.Tag,
|
||||
GroupID: pm.Group,
|
||||
ChannelID: pm.Channel,
|
||||
RoleName: pm.RoleName,
|
||||
ConnType: pm.ConnectionType,
|
||||
RoleID: pm.RoleID,
|
||||
Actions: pm.Actions,
|
||||
AccessType: pm.AccessType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type dbClientsPage struct {
|
||||
Total uint64 `db:"total"`
|
||||
Limit uint64 `db:"limit"`
|
||||
Offset uint64 `db:"offset"`
|
||||
Name string `db:"name"`
|
||||
Id string `db:"id"`
|
||||
Domain string `db:"domain_id"`
|
||||
Identity string `db:"identity"`
|
||||
Metadata []byte `db:"metadata"`
|
||||
Tag string `db:"tag"`
|
||||
Status clients.Status `db:"status"`
|
||||
GroupID string `db:"group_id"`
|
||||
Limit uint64 `db:"limit"`
|
||||
Offset uint64 `db:"offset"`
|
||||
Name string `db:"name"`
|
||||
Id string `db:"id"`
|
||||
Domain string `db:"domain_id"`
|
||||
Identity string `db:"identity"`
|
||||
Metadata []byte `db:"metadata"`
|
||||
Tag string `db:"tag"`
|
||||
Status clients.Status `db:"status"`
|
||||
GroupID string `db:"group_id"`
|
||||
ChannelID string `db:"channel_id"`
|
||||
ConnType string `db:"type"`
|
||||
RoleName string `db:"role_name"`
|
||||
RoleID string `db:"role_id"`
|
||||
Actions pq.StringArray `db:"actions"`
|
||||
AccessType string `db:"access_type"`
|
||||
}
|
||||
|
||||
func PageQuery(pm clients.Page) (string, error) {
|
||||
@@ -534,36 +849,57 @@ func PageQuery(pm clients.Page) (string, error) {
|
||||
|
||||
var query []string
|
||||
if pm.Name != "" {
|
||||
query = append(query, "name ILIKE '%' || :name || '%'")
|
||||
query = append(query, "c.name ILIKE '%' || :name || '%'")
|
||||
}
|
||||
if pm.Identity != "" {
|
||||
query = append(query, "identity ILIKE '%' || :identity || '%'")
|
||||
query = append(query, "c.identity ILIKE '%' || :identity || '%'")
|
||||
}
|
||||
if pm.Id != "" {
|
||||
query = append(query, "id ILIKE '%' || :id || '%'")
|
||||
query = append(query, "c.id ILIKE '%' || :id || '%'")
|
||||
}
|
||||
if pm.Tag != "" {
|
||||
query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')")
|
||||
}
|
||||
// If there are search params presents, use search and ignore other options.
|
||||
// Always combine role with search params, so len(query) > 1.
|
||||
if len(query) > 1 {
|
||||
return fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")), nil
|
||||
}
|
||||
|
||||
if mq != "" {
|
||||
query = append(query, mq)
|
||||
}
|
||||
|
||||
if len(pm.IDs) != 0 {
|
||||
query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','")))
|
||||
query = append(query, fmt.Sprintf("c.id IN ('%s')", strings.Join(pm.IDs, "','")))
|
||||
}
|
||||
|
||||
if pm.Status != clients.AllStatus {
|
||||
query = append(query, "c.status = :status")
|
||||
}
|
||||
if pm.Domain != "" {
|
||||
query = append(query, "c.domain_id = :domain_id")
|
||||
}
|
||||
if pm.Group != "" {
|
||||
query = append(query, "c.parent_group_path @> (SELECT path from groups where id = :group_id) ")
|
||||
}
|
||||
if pm.Channel != "" {
|
||||
query = append(query, "conn.channel_id = :channel_id ")
|
||||
if pm.ConnectionType != "" {
|
||||
query = append(query, "conn.type = :conn_type ")
|
||||
}
|
||||
}
|
||||
if pm.AccessType != "" {
|
||||
query = append(query, "c.access_type = :access_type")
|
||||
}
|
||||
if pm.RoleID != "" {
|
||||
query = append(query, "c.role_id = :role_id")
|
||||
}
|
||||
if pm.RoleName != "" {
|
||||
query = append(query, "c.role_name = :role_name")
|
||||
}
|
||||
if len(pm.Actions) != 0 {
|
||||
query = append(query, "c.actions @> :actions")
|
||||
}
|
||||
if len(pm.Metadata) > 0 {
|
||||
query = append(query, "c.metadata @> :metadata")
|
||||
}
|
||||
|
||||
var emq string
|
||||
if len(query) > 0 {
|
||||
emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND "))
|
||||
|
||||
@@ -1626,274 +1626,7 @@ func TestSearchClients(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveAllByIDs(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM clients")
|
||||
require.Nil(t, err, fmt.Sprintf("clean clients unexpected error: %s", err))
|
||||
})
|
||||
|
||||
repo := postgres.NewRepository(database)
|
||||
|
||||
num := 200
|
||||
|
||||
var items []clients.Client
|
||||
for i := 0; i < num; i++ {
|
||||
name := namegen.Generate()
|
||||
client := clients.Client{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Domain: testsutil.GenerateUUID(t),
|
||||
Name: name,
|
||||
Credentials: clients.Credentials{
|
||||
Identity: name + emailSuffix,
|
||||
Secret: testsutil.GenerateUUID(t),
|
||||
},
|
||||
Tags: namegen.GenerateMultiple(5),
|
||||
Metadata: map[string]interface{}{"name": name},
|
||||
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
|
||||
Status: clients.EnabledStatus,
|
||||
}
|
||||
_, err := repo.Save(context.Background(), client)
|
||||
require.Nil(t, err, fmt.Sprintf("add new client: expected nil got %s\n", err))
|
||||
items = append(items, client)
|
||||
}
|
||||
|
||||
page, err := repo.RetrieveAll(context.Background(), clients.Page{Offset: 0, Limit: uint64(num)})
|
||||
require.Nil(t, err, fmt.Sprintf("retrieve all clients unexpected error: %s", err))
|
||||
assert.Equal(t, uint64(num), page.Total)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
page clients.Page
|
||||
response clients.ClientsPage
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "successfully",
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
IDs: getIDs(items[0:3]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 3,
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
Clients: items[0:3],
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with empty ids",
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
IDs: []string{},
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
Clients: []clients.Client(nil),
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with empty ids but with domain id",
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Domain: items[0].Domain,
|
||||
IDs: []string{},
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 1,
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
Clients: []clients.Client{items[0]},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with offset only",
|
||||
page: clients.Page{
|
||||
Offset: 10,
|
||||
IDs: getIDs(items[0:20]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 20,
|
||||
Offset: 10,
|
||||
Limit: 0,
|
||||
},
|
||||
Clients: []clients.Client(nil),
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with limit only",
|
||||
page: clients.Page{
|
||||
Limit: 10,
|
||||
IDs: getIDs(items[0:20]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 20,
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
Clients: items[0:10],
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with offset out of range",
|
||||
page: clients.Page{
|
||||
Offset: 1000,
|
||||
Limit: 50,
|
||||
IDs: getIDs(items[0:20]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 20,
|
||||
Offset: 1000,
|
||||
Limit: 50,
|
||||
},
|
||||
Clients: []clients.Client(nil),
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with offset and limit out of range",
|
||||
page: clients.Page{
|
||||
Offset: 15,
|
||||
Limit: 10,
|
||||
IDs: getIDs(items[0:20]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 20,
|
||||
Offset: 15,
|
||||
Limit: 10,
|
||||
},
|
||||
Clients: items[15:20],
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with limit out of range",
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 1000,
|
||||
IDs: getIDs(items[0:20]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 20,
|
||||
Offset: 0,
|
||||
Limit: 1000,
|
||||
},
|
||||
Clients: items[:20],
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with name",
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Name: items[0].Name,
|
||||
IDs: getIDs(items[0:20]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 1,
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
Clients: []clients.Client{items[0]},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with domain id",
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Domain: items[0].Domain,
|
||||
IDs: getIDs(items[0:20]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 1,
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
Clients: []clients.Client{items[0]},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with metadata",
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Metadata: items[0].Metadata,
|
||||
IDs: getIDs(items[0:20]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 1,
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
Clients: []clients.Client{items[0]},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "with invalid metadata",
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Metadata: map[string]interface{}{
|
||||
"key": make(chan int),
|
||||
},
|
||||
IDs: getIDs(items[0:20]),
|
||||
},
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 0,
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
Clients: []clients.Client(nil),
|
||||
},
|
||||
err: errors.ErrMalformedEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
switch response, err := repo.RetrieveAllByIDs(context.Background(), c.page); {
|
||||
case err == nil:
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", c.desc, c.err, err))
|
||||
assert.Equal(t, c.response.Total, response.Total)
|
||||
assert.Equal(t, c.response.Limit, response.Limit)
|
||||
assert.Equal(t, c.response.Offset, response.Offset)
|
||||
expected := stripClientDetails(c.response.Clients)
|
||||
got := stripClientDetails(response.Clients)
|
||||
assert.ElementsMatch(t, expected, got)
|
||||
default:
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievByIDs(t *testing.T) {
|
||||
func TestRetrieveByIDs(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM clients")
|
||||
require.Nil(t, err, fmt.Sprintf("clean clients unexpected error: %s", err))
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
gpostgres "github.com/absmach/supermq/groups/postgres"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
|
||||
@@ -60,5 +61,12 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
|
||||
|
||||
clientsMigration.Migrations = append(clientsMigration.Migrations, clientsRolesMigration.Migrations...)
|
||||
|
||||
groupsMigration, err := gpostgres.Migration()
|
||||
if err != nil {
|
||||
return &migrate.MemoryMigrationSource{}, err
|
||||
}
|
||||
|
||||
clientsMigration.Migrations = append(clientsMigration.Migrations, groupsMigration.Migrations...)
|
||||
|
||||
return clientsMigration, nil
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func NewRolesOperationPermissionMap() map[svcutil.Operation]svcutil.Permission {
|
||||
const (
|
||||
// External Permission for domains.
|
||||
domainCreateClientPermission = "client_create_permission"
|
||||
domainListClientsPermission = "list_clients_permission"
|
||||
domainListClientsPermission = "client_read_permission"
|
||||
// External Permission for groups.
|
||||
groupSetChildClientPermission = "client_create_permission"
|
||||
groupRemoveChildClientPermission = "client_create_permission"
|
||||
|
||||
+14
-100
@@ -12,13 +12,11 @@ import (
|
||||
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
|
||||
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
smqauth "github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -131,113 +129,29 @@ func (svc service) View(ctx context.Context, session authn.Session, id string) (
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (svc service) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm Page) (ClientsPage, error) {
|
||||
var ids []string
|
||||
var err error
|
||||
switch {
|
||||
case (reqUserID != "" && reqUserID != session.UserID):
|
||||
rtids, err := svc.listClientIDs(ctx, smqauth.EncodeDomainUserID(session.DomainID, reqUserID), pm.Permission)
|
||||
func (svc service) ListClients(ctx context.Context, session authn.Session, pm Page) (ClientsPage, error) {
|
||||
switch session.SuperAdmin {
|
||||
case true:
|
||||
cp, err := svc.repo.RetrieveAll(ctx, pm)
|
||||
if err != nil {
|
||||
return ClientsPage{}, errors.Wrap(svcerr.ErrNotFound, err)
|
||||
}
|
||||
ids, err = svc.filterAllowedClientIDs(ctx, session.DomainUserID, pm.Permission, rtids)
|
||||
if err != nil {
|
||||
return ClientsPage{}, errors.Wrap(svcerr.ErrNotFound, err)
|
||||
return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
return cp, nil
|
||||
default:
|
||||
switch session.SuperAdmin {
|
||||
case true:
|
||||
pm.Domain = session.DomainID
|
||||
default:
|
||||
ids, err = svc.listClientIDs(ctx, session.DomainUserID, pm.Permission)
|
||||
if err != nil {
|
||||
return ClientsPage{}, errors.Wrap(svcerr.ErrNotFound, err)
|
||||
}
|
||||
cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, session.UserID, pm)
|
||||
if err != nil {
|
||||
return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
return cp, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(ids) == 0 && pm.Domain == "" {
|
||||
return ClientsPage{}, nil
|
||||
}
|
||||
pm.IDs = ids
|
||||
tp, err := svc.repo.SearchClients(ctx, pm)
|
||||
func (svc service) ListUserClients(ctx context.Context, session authn.Session, userID string, pm Page) (ClientsPage, error) {
|
||||
cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, userID, pm)
|
||||
if err != nil {
|
||||
return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
if pm.ListPerms && len(tp.Clients) > 0 {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
for i := range tp.Clients {
|
||||
// Copying loop variable "i" to avoid "loop variable captured by func literal"
|
||||
iter := i
|
||||
g.Go(func() error {
|
||||
return svc.retrievePermissions(ctx, session.DomainUserID, &tp.Clients[iter])
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return ClientsPage{}, err
|
||||
}
|
||||
}
|
||||
return tp, nil
|
||||
}
|
||||
|
||||
// Experimental functions used for async calling of svc.listUserClientPermission. This might be helpful during listing of large number of entities.
|
||||
func (svc service) retrievePermissions(ctx context.Context, userID string, client *Client) error {
|
||||
permissions, err := svc.listUserClientPermission(ctx, userID, client.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client.Permissions = permissions
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) listUserClientPermission(ctx context.Context, userID, clientID string) ([]string, error) {
|
||||
permissions, err := svc.policy.ListPermissions(ctx, policies.Policy{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: userID,
|
||||
Object: clientID,
|
||||
ObjectType: policies.ClientType,
|
||||
}, []string{})
|
||||
if err != nil {
|
||||
return []string{}, errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
func (svc service) listClientIDs(ctx context.Context, userID, permission string) ([]string, error) {
|
||||
tids, err := svc.policy.ListAllObjects(ctx, policies.Policy{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: userID,
|
||||
Permission: permission,
|
||||
ObjectType: policies.ClientType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(svcerr.ErrNotFound, err)
|
||||
}
|
||||
return tids.Policies, nil
|
||||
}
|
||||
|
||||
func (svc service) filterAllowedClientIDs(ctx context.Context, userID, permission string, clientIDs []string) ([]string, error) {
|
||||
var ids []string
|
||||
tids, err := svc.policy.ListAllObjects(ctx, policies.Policy{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: userID,
|
||||
Permission: permission,
|
||||
ObjectType: policies.ClientType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(svcerr.ErrNotFound, err)
|
||||
}
|
||||
for _, clientID := range clientIDs {
|
||||
for _, tid := range tids.Policies {
|
||||
if clientID == tid {
|
||||
ids = append(ids, clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ids, nil
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
func (svc service) Update(ctx context.Context, session authn.Session, cli Client) (Client, error) {
|
||||
|
||||
+24
-99
@@ -351,7 +351,6 @@ func TestListClients(t *testing.T) {
|
||||
adminID := testsutil.GenerateUUID(t)
|
||||
domainID := testsutil.GenerateUUID(t)
|
||||
nonAdminID := testsutil.GenerateUUID(t)
|
||||
client.Permissions = []string{"read", "edit"}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -375,9 +374,8 @@ func TestListClients(t *testing.T) {
|
||||
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
|
||||
id: nonAdminID,
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
ListPerms: true,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}},
|
||||
retrieveAllResponse: clients.ClientsPage{
|
||||
@@ -388,7 +386,6 @@ func TestListClients(t *testing.T) {
|
||||
},
|
||||
Clients: []clients.Client{client, client},
|
||||
},
|
||||
listPermissionsResponse: client.Permissions,
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 2,
|
||||
@@ -405,9 +402,8 @@ func TestListClients(t *testing.T) {
|
||||
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
|
||||
id: nonAdminID,
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
ListPerms: true,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}},
|
||||
retrieveAllResponse: clients.ClientsPage{},
|
||||
@@ -415,39 +411,14 @@ func TestListClients(t *testing.T) {
|
||||
retrieveAllErr: repoerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "list all clients as non admin with failed to list permissions",
|
||||
userKind: "non-admin",
|
||||
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
|
||||
id: nonAdminID,
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
ListPerms: true,
|
||||
},
|
||||
listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}},
|
||||
retrieveAllResponse: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 2,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
Clients: []clients.Client{client, client},
|
||||
},
|
||||
listPermissionsResponse: []string{},
|
||||
response: clients.ClientsPage{},
|
||||
listPermissionsErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "list all clients as non admin with failed super admin",
|
||||
userKind: "non-admin",
|
||||
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
|
||||
id: nonAdminID,
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
ListPerms: true,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
response: clients.ClientsPage{},
|
||||
listObjectsResponse: policysvc.PolicyPage{},
|
||||
@@ -458,10 +429,10 @@ func TestListClients(t *testing.T) {
|
||||
userKind: "non-admin",
|
||||
id: nonAdminID,
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
ListPerms: true,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
retrieveAllErr: repoerr.ErrNotFound,
|
||||
response: clients.ClientsPage{},
|
||||
listObjectsResponse: policysvc.PolicyPage{},
|
||||
listObjectsErr: svcerr.ErrNotFound,
|
||||
@@ -470,15 +441,13 @@ func TestListClients(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
listAllObjectsCall := pService.On("ListAllObjects", mock.Anything, mock.Anything).Return(tc.listObjectsResponse, tc.listObjectsErr)
|
||||
retrieveAllCall := repo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
listPermissionsCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr)
|
||||
page, err := svc.ListClients(context.Background(), tc.session, tc.id, tc.page)
|
||||
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
retrieveUserClientsCall := repo.On("RetrieveUserClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
page, err := svc.ListClients(context.Background(), tc.session, tc.page)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
|
||||
listAllObjectsCall.Unset()
|
||||
retrieveAllCall.Unset()
|
||||
listPermissionsCall.Unset()
|
||||
retrieveUserClientsCall.Unset()
|
||||
}
|
||||
|
||||
cases2 := []struct {
|
||||
@@ -503,10 +472,9 @@ func TestListClients(t *testing.T) {
|
||||
id: adminID,
|
||||
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
ListPerms: true,
|
||||
Domain: domainID,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Domain: domainID,
|
||||
},
|
||||
listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}},
|
||||
retrieveAllResponse: clients.ClientsPage{
|
||||
@@ -517,7 +485,6 @@ func TestListClients(t *testing.T) {
|
||||
},
|
||||
Clients: []clients.Client{client, client},
|
||||
},
|
||||
listPermissionsResponse: client.Permissions,
|
||||
response: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 2,
|
||||
@@ -534,50 +501,24 @@ func TestListClients(t *testing.T) {
|
||||
id: adminID,
|
||||
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
ListPerms: true,
|
||||
Domain: domainID,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Domain: domainID,
|
||||
},
|
||||
listObjectsResponse: policysvc.PolicyPage{},
|
||||
retrieveAllResponse: clients.ClientsPage{},
|
||||
retrieveAllErr: repoerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "list all clients as admin with failed to list permissions",
|
||||
userKind: "admin",
|
||||
id: adminID,
|
||||
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
ListPerms: true,
|
||||
Domain: domainID,
|
||||
},
|
||||
listObjectsResponse: policysvc.PolicyPage{},
|
||||
retrieveAllResponse: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Total: 2,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
Clients: []clients.Client{client, client},
|
||||
},
|
||||
listPermissionsResponse: []string{},
|
||||
listPermissionsErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "list all clients as admin with failed to list clients",
|
||||
userKind: "admin",
|
||||
id: adminID,
|
||||
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
|
||||
page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
ListPerms: true,
|
||||
Domain: domainID,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Domain: domainID,
|
||||
},
|
||||
retrieveAllResponse: clients.ClientsPage{},
|
||||
retrieveAllErr: repoerr.ErrNotFound,
|
||||
@@ -586,27 +527,11 @@ func TestListClients(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases2 {
|
||||
listAllObjectsCall := pService.On("ListAllObjects", context.Background(), policysvc.Policy{
|
||||
SubjectType: policysvc.UserType,
|
||||
Subject: tc.session.DomainID + "_" + adminID,
|
||||
Permission: "",
|
||||
ObjectType: policysvc.ClientType,
|
||||
}).Return(tc.listObjectsResponse, tc.listObjectsErr)
|
||||
listAllObjectsCall2 := pService.On("ListAllObjects", context.Background(), policysvc.Policy{
|
||||
SubjectType: policysvc.UserType,
|
||||
Subject: tc.session.UserID,
|
||||
Permission: "",
|
||||
ObjectType: policysvc.ClientType,
|
||||
}).Return(tc.listObjectsResponse, tc.listObjectsErr)
|
||||
retrieveAllCall := repo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
listPermissionsCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr)
|
||||
page, err := svc.ListClients(context.Background(), tc.session, tc.id, tc.page)
|
||||
retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
page, err := svc.ListClients(context.Background(), tc.session, tc.page)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
|
||||
listAllObjectsCall.Unset()
|
||||
listAllObjectsCall2.Unset()
|
||||
retrieveAllCall.Unset()
|
||||
listPermissionsCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -197,7 +197,7 @@ func TestStatusUnmarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserMarshalJSON(t *testing.T) {
|
||||
func TestClientMarshalJSON(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
expected []byte
|
||||
|
||||
@@ -47,10 +47,16 @@ func (tm *tracingMiddleware) View(ctx context.Context, session authn.Session, id
|
||||
}
|
||||
|
||||
// ListClients traces the "ListClients" operation of the wrapped clients.Service.
|
||||
func (tm *tracingMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
func (tm *tracingMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_list_clients")
|
||||
defer span.End()
|
||||
return tm.svc.ListClients(ctx, session, reqUserID, pm)
|
||||
return tm.svc.ListClients(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_list_clients")
|
||||
defer span.End()
|
||||
return tm.svc.ListUserClients(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
// Update traces the "Update" operation of the wrapped clients.Service.
|
||||
|
||||
@@ -25,11 +25,13 @@ import (
|
||||
"github.com/absmach/supermq/channels/postgres"
|
||||
pChannels "github.com/absmach/supermq/channels/private"
|
||||
"github.com/absmach/supermq/channels/tracing"
|
||||
gpostgres "github.com/absmach/supermq/groups/postgres"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
|
||||
gconsumer "github.com/absmach/supermq/pkg/groups/events/consumer"
|
||||
"github.com/absmach/supermq/pkg/grpcclient"
|
||||
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
@@ -74,6 +76,7 @@ type config struct {
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
ESConsumerName string `env:"SMQ_CLIENTS_EVENT_CONSUMER" envDefault:"channels"`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
@@ -224,6 +227,15 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
gdatabase := pg.NewDatabase(db, dbConfig, tracer)
|
||||
grepo := gpostgres.New(gdatabase)
|
||||
|
||||
if err := gconsumer.GroupsEventsSubscribe(ctx, grepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create groups event store : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
grpcServerConfig := server.Config{Port: defSvcGRPCPort}
|
||||
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load %s gRPC server configuration : %s", svcName, err))
|
||||
|
||||
@@ -27,12 +27,14 @@ import (
|
||||
"github.com/absmach/supermq/clients/postgres"
|
||||
pClients "github.com/absmach/supermq/clients/private"
|
||||
"github.com/absmach/supermq/clients/tracing"
|
||||
gpostgres "github.com/absmach/supermq/groups/postgres"
|
||||
redisclient "github.com/absmach/supermq/internal/clients/redis"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
|
||||
gconsumer "github.com/absmach/supermq/pkg/groups/events/consumer"
|
||||
"github.com/absmach/supermq/pkg/grpcclient"
|
||||
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
|
||||
"github.com/absmach/supermq/pkg/policies"
|
||||
@@ -82,6 +84,7 @@ type config struct {
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
ESConsumerName string `env:"SMQ_CLIENTS_EVENT_CONSUMER" envDefault:"clients"`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
@@ -241,6 +244,15 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
gdatabase := pg.NewDatabase(db, dbConfig, tracer)
|
||||
grepo := gpostgres.New(gdatabase)
|
||||
|
||||
if err := gconsumer.GroupsEventsSubscribe(ctx, grepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to create groups event store : %s", err))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
|
||||
httpServerConfig := server.Config{Port: defSvcHTTPPort}
|
||||
if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err))
|
||||
|
||||
@@ -357,13 +357,13 @@ type addChildrenGroupsEvent struct {
|
||||
|
||||
func (acge addChildrenGroupsEvent) Encode() (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"operation": groupAddChildrenGroups,
|
||||
"id": acge.id,
|
||||
"childre_ids": acge.childrenIDs,
|
||||
"domain": acge.DomainID,
|
||||
"user_id": acge.UserID,
|
||||
"token_type": acge.Type.String(),
|
||||
"super_admin": acge.SuperAdmin,
|
||||
"operation": groupAddChildrenGroups,
|
||||
"id": acge.id,
|
||||
"children_ids": acge.childrenIDs,
|
||||
"domain": acge.DomainID,
|
||||
"user_id": acge.UserID,
|
||||
"token_type": acge.Type.String(),
|
||||
"super_admin": acge.SuperAdmin,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
+2
-1
@@ -142,9 +142,10 @@ type Service interface {
|
||||
// ViewGroup retrieves data about the group identified by ID.
|
||||
ViewGroup(ctx context.Context, session authn.Session, id string) (Group, error)
|
||||
|
||||
// ListGroups retrieves
|
||||
// ListGroups retrieves groups for given filters.
|
||||
ListGroups(ctx context.Context, session authn.Session, pm PageMeta) (Page, error)
|
||||
|
||||
// ListGroups retrieves user accessible groups for given filters.
|
||||
ListUserGroups(ctx context.Context, session authn.Session, userID string, pm PageMeta) (Page, error)
|
||||
|
||||
// EnableGroup logically enables the group identified with the provided ID.
|
||||
|
||||
@@ -73,6 +73,7 @@ func AuthorizationMiddleware(entityType string, svc groups.Service, repo groups.
|
||||
}
|
||||
return &authorizationMiddleware{
|
||||
svc: svc,
|
||||
repo: repo,
|
||||
authz: authz,
|
||||
opp: opp,
|
||||
extOpp: extOpp,
|
||||
|
||||
@@ -34,7 +34,8 @@ var (
|
||||
// ErrFailedToRetrieveAllGroups failed to retrieve groups.
|
||||
ErrFailedToRetrieveAllGroups = errors.New("failed to retrieve all groups")
|
||||
|
||||
ErrRoleMigration = errors.New("role migration initialization failed")
|
||||
// ErrRoleMigration failed to apply role migrations.
|
||||
ErrRoleMigration = errors.New("failed to apply role migration")
|
||||
|
||||
// ErrMissingNames indicates missing first and last names.
|
||||
ErrMissingNames = errors.New("missing first or last name")
|
||||
|
||||
@@ -43,6 +43,7 @@ type SubscriberConfig struct {
|
||||
Consumer string
|
||||
Stream string
|
||||
Handler EventHandler
|
||||
Ordered bool
|
||||
}
|
||||
|
||||
// Subscriber specifies event subscription API.
|
||||
|
||||
@@ -93,6 +93,7 @@ func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberCon
|
||||
logger: es.logger,
|
||||
},
|
||||
DeliveryPolicy: messaging.DeliverNewPolicy,
|
||||
Ordered: cfg.Ordered,
|
||||
}
|
||||
|
||||
return es.pubsub.Subscribe(ctx, subCfg)
|
||||
@@ -126,8 +127,9 @@ func (eh *eventHandler) Handle(msg *messaging.Message) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := eh.handler.Handle(eh.ctx, event); err != nil {
|
||||
eh.logger.Warn(fmt.Sprintf("failed to handle nats event: %s", err))
|
||||
err := eh.handler.Handle(eh.ctx, event)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to handle nats event: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package consumer
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/groups"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
rconsumer "github.com/absmach/supermq/pkg/roles/rolemanager/events/consumer"
|
||||
)
|
||||
|
||||
var (
|
||||
errDecodeCreateGroupEvent = errors.New("failed to decode group create event")
|
||||
errDecodeUpdateGroupEvent = errors.New("failed to decode group update event")
|
||||
errDecodeChangeStatusGroupEvent = errors.New("failed to decode group change status event")
|
||||
errDecodeRemoveGroupEvent = errors.New("failed to decode group remove event")
|
||||
errDecodeAddParentGroupEvent = errors.New("failed to decode group add parent event")
|
||||
errDecodeRemoveParentGroupEvent = errors.New("failed to decode group remove parent event")
|
||||
errDecodeAddChildrenGroupsEvent = errors.New("failed to decode group add children groups event")
|
||||
errDecodeRemoveChildrenGroupsEvent = errors.New("failed to decode group remove children groups event")
|
||||
|
||||
errID = errors.New("missing or invalid 'id'")
|
||||
errName = errors.New("missing or invalid 'name'")
|
||||
errDomain = errors.New("missing or invalid 'domain'")
|
||||
errParent = errors.New("missing or invalid 'parent'")
|
||||
errChildrenIDs = errors.New("missing or invalid 'children_ids'")
|
||||
errStatus = errors.New("missing or invalid 'status'")
|
||||
errConvertStatus = errors.New("failed to convert status")
|
||||
errCreatedAt = errors.New("failed to parse 'created_at' time")
|
||||
errUpdatedAt = errors.New("failed to parse 'updated_at' time")
|
||||
)
|
||||
|
||||
const (
|
||||
layout = "2006-01-02T15:04:05.999999Z"
|
||||
)
|
||||
|
||||
func ToGroups(data map[string]interface{}) (groups.Group, error) {
|
||||
var g groups.Group
|
||||
id, ok := data["id"].(string)
|
||||
if !ok {
|
||||
return groups.Group{}, errID
|
||||
}
|
||||
g.ID = id
|
||||
|
||||
name, ok := data["name"].(string)
|
||||
if !ok {
|
||||
return groups.Group{}, errName
|
||||
}
|
||||
g.Name = name
|
||||
|
||||
dom, ok := data["domain"].(string)
|
||||
if !ok {
|
||||
return groups.Group{}, errDomain
|
||||
}
|
||||
g.Domain = dom
|
||||
|
||||
stat, ok := data["status"].(string)
|
||||
if !ok {
|
||||
return groups.Group{}, errStatus
|
||||
}
|
||||
st, err := groups.ToStatus(stat)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(errConvertStatus, err)
|
||||
}
|
||||
g.Status = st
|
||||
|
||||
cat, ok := data["created_at"].(string)
|
||||
if !ok {
|
||||
return groups.Group{}, errCreatedAt
|
||||
}
|
||||
ct, err := time.Parse(layout, cat)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(errCreatedAt, err)
|
||||
}
|
||||
g.CreatedAt = ct
|
||||
|
||||
// Following fields of groups are allowed to be empty.
|
||||
|
||||
desc, ok := data["description"].(string)
|
||||
if ok {
|
||||
g.Description = desc
|
||||
}
|
||||
|
||||
parent, ok := data["parent"].(string)
|
||||
if ok {
|
||||
g.Parent = parent
|
||||
}
|
||||
|
||||
meta, ok := data["metadata"].(map[string]interface{})
|
||||
if ok {
|
||||
g.Metadata = meta
|
||||
}
|
||||
|
||||
uby, ok := data["updated_by"].(string)
|
||||
if ok {
|
||||
g.UpdatedBy = uby
|
||||
}
|
||||
|
||||
uat, ok := data["updated_at"].(string)
|
||||
if ok {
|
||||
ut, err := time.Parse(layout, uat)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(errUpdatedAt, err)
|
||||
}
|
||||
g.UpdatedAt = ut
|
||||
}
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func decodeCreateGroupEvent(data map[string]interface{}) (groups.Group, []roles.RoleProvision, error) {
|
||||
g, err := ToGroups(data)
|
||||
if err != nil {
|
||||
return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDecodeCreateGroupEvent, err)
|
||||
}
|
||||
irps, ok := data["roles_provisioned"].([]interface{})
|
||||
if !ok {
|
||||
return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDecodeCreateGroupEvent, errors.New("missing or invalid 'roles_provisioned'"))
|
||||
}
|
||||
rps, err := rconsumer.ToRoleProvisions(irps)
|
||||
if err != nil {
|
||||
return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDecodeCreateGroupEvent, err)
|
||||
}
|
||||
|
||||
return g, rps, nil
|
||||
}
|
||||
|
||||
func decodeUpdateGroupEvent(data map[string]interface{}) (groups.Group, error) {
|
||||
g, err := ToGroups(data)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(errDecodeUpdateGroupEvent, err)
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func ToGroupStatus(data map[string]interface{}) (groups.Group, error) {
|
||||
var g groups.Group
|
||||
id, ok := data["id"].(string)
|
||||
if !ok {
|
||||
return groups.Group{}, errID
|
||||
}
|
||||
g.ID = id
|
||||
|
||||
stat, ok := data["status"].(string)
|
||||
if !ok {
|
||||
return groups.Group{}, errStatus
|
||||
}
|
||||
st, err := groups.ToStatus(stat)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(errConvertStatus, err)
|
||||
}
|
||||
g.Status = st
|
||||
|
||||
uat, ok := data["updated_at"].(string)
|
||||
if ok {
|
||||
ut, err := time.Parse(layout, uat)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(errUpdatedAt, err)
|
||||
}
|
||||
g.UpdatedAt = ut
|
||||
}
|
||||
|
||||
uby, ok := data["updated_by"].(string)
|
||||
if ok {
|
||||
g.UpdatedBy = uby
|
||||
}
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func decodeChangeStatusGroupEvent(data map[string]interface{}) (groups.Group, error) {
|
||||
g, err := ToGroupStatus(data)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(errDecodeChangeStatusGroupEvent, err)
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func decodeRemoveGroupEvent(data map[string]interface{}) (groups.Group, error) {
|
||||
var g groups.Group
|
||||
id, ok := data["id"].(string)
|
||||
if !ok {
|
||||
return groups.Group{}, errors.Wrap(errDecodeRemoveGroupEvent, errID)
|
||||
}
|
||||
g.ID = id
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func decodeAddParentGroupEvent(data map[string]interface{}) (id string, parent string, err error) {
|
||||
id, ok := data["id"].(string)
|
||||
if !ok {
|
||||
return "", "", errors.Wrap(errAddParentGroupEvent, errID)
|
||||
}
|
||||
|
||||
parent, ok = data["parent_id"].(string)
|
||||
if !ok {
|
||||
return "", "", errors.Wrap(errDecodeAddParentGroupEvent, errParent)
|
||||
}
|
||||
|
||||
return id, parent, nil
|
||||
}
|
||||
|
||||
func decodeRemoveParentGroupEvent(data map[string]interface{}) (id string, err error) {
|
||||
id, ok := data["id"].(string)
|
||||
if !ok {
|
||||
return "", errors.Wrap(errDecodeRemoveParentGroupEvent, errID)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func decodeAddChildrenGroupEvent(data map[string]interface{}) (id string, childrenIDs []string, err error) {
|
||||
id, ok := data["id"].(string)
|
||||
if !ok {
|
||||
return "", []string{}, errors.Wrap(errDecodeAddChildrenGroupsEvent, errID)
|
||||
}
|
||||
chIDs, ok := data["children_ids"].([]interface{})
|
||||
if !ok {
|
||||
return "", []string{}, errors.Wrap(errDecodeAddChildrenGroupsEvent, errChildrenIDs)
|
||||
}
|
||||
cids, err := rconsumer.ToStrings(chIDs)
|
||||
if err != nil {
|
||||
return "", []string{}, errors.Wrap(errDecodeAddChildrenGroupsEvent, errors.Wrap(errChildrenIDs, err))
|
||||
}
|
||||
return id, cids, nil
|
||||
}
|
||||
|
||||
func decodeRemoveChildrenGroupEvent(data map[string]interface{}) (id string, childrenIDs []string, err error) {
|
||||
id, ok := data["id"].(string)
|
||||
if !ok {
|
||||
return "", []string{}, errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errID)
|
||||
}
|
||||
chIDs, ok := data["children_ids"].([]interface{})
|
||||
if !ok {
|
||||
return "", []string{}, errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errChildrenIDs)
|
||||
}
|
||||
cids, err := rconsumer.ToStrings(chIDs)
|
||||
if err != nil {
|
||||
return "", []string{}, errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errors.Wrap(errChildrenIDs, err))
|
||||
}
|
||||
return id, cids, nil
|
||||
}
|
||||
|
||||
func decodeRemoveAllChildrenGroupEvent(data map[string]interface{}) (id string, err error) {
|
||||
id, ok := data["id"].(string)
|
||||
if !ok {
|
||||
return "", errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errID)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package consumer contains events consumer for events
|
||||
// published by Bootstrap service.
|
||||
package consumer
|
||||
@@ -0,0 +1,253 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package consumer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/supermq/groups"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
"github.com/absmach/supermq/pkg/events"
|
||||
"github.com/absmach/supermq/pkg/events/store"
|
||||
rconsumer "github.com/absmach/supermq/pkg/roles/rolemanager/events/consumer"
|
||||
)
|
||||
|
||||
const (
|
||||
stream = "events.supermq.groups"
|
||||
|
||||
create = "group.create"
|
||||
update = "group.update"
|
||||
changeStatus = "group.change_status"
|
||||
remove = "group.remove"
|
||||
addParentGroup = "group.add_parent_group"
|
||||
removeParentGroup = "group.remove_parent_group"
|
||||
addChildrenGroups = "group.add_children_groups"
|
||||
removeChildrenGroups = "group.remove_children_groups"
|
||||
removeAllChildrenGroups = "group.remove_all_children_groups"
|
||||
addRole = "group.role.add"
|
||||
removeRole = "group.role.remove"
|
||||
updateRole = "group.role.update"
|
||||
addRoleActions = "group.role.actions.add"
|
||||
removeRoleActions = "group.role.actions.remove"
|
||||
removeAllRoleActions = "group.role.actions.remove_all"
|
||||
addRoleMembers = "group.role.members.add"
|
||||
removeRoleMembers = "group.role.members.remove"
|
||||
removeRoleAllMembers = "group.role.members.remove_all"
|
||||
removeMemberFromAllRoles = "group.role.members.remove_from_all_roles"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoOperationKey = errors.New("operation key is not found in event message")
|
||||
errCreateGroupEvent = errors.New("failed to consume group create event")
|
||||
errUpdateGroupEvent = errors.New("failed to consume group update event")
|
||||
errChangeStatusGroupEvent = errors.New("failed to consume group change status event")
|
||||
errRemoveGroupEvent = errors.New("failed to consume group remove event")
|
||||
errAddParentGroupEvent = errors.New("failed to consume group add parent group event")
|
||||
errRemoveParentGroupEvent = errors.New("failed to consume group remove parent group event")
|
||||
errAddChildrenGroupEvent = errors.New("failed to consume group add children groups event")
|
||||
errRemoveChildrenGroupEvent = errors.New("failed to consume group remove children groups event")
|
||||
errRemoveAllChildrenGroupEvent = errors.New("failed to consume group remove all children groups event")
|
||||
)
|
||||
|
||||
type eventHandler struct {
|
||||
repo groups.Repository
|
||||
rolesEventHandler rconsumer.EventHandler
|
||||
}
|
||||
|
||||
func GroupsEventsSubscribe(ctx context.Context, repo groups.Repository, esURL, esConsumerName string, logger *slog.Logger) error {
|
||||
subscriber, err := store.NewSubscriber(ctx, esURL, logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subConfig := events.SubscriberConfig{
|
||||
Stream: stream,
|
||||
Consumer: esConsumerName,
|
||||
Handler: NewEventHandler(repo),
|
||||
Ordered: true,
|
||||
}
|
||||
return subscriber.Subscribe(ctx, subConfig)
|
||||
}
|
||||
|
||||
// NewEventHandler returns new event store handler.
|
||||
func NewEventHandler(repo groups.Repository) events.EventHandler {
|
||||
reh := rconsumer.NewEventHandler("group", repo)
|
||||
return &eventHandler{
|
||||
repo: repo,
|
||||
rolesEventHandler: reh,
|
||||
}
|
||||
}
|
||||
|
||||
func (es *eventHandler) Handle(ctx context.Context, event events.Event) error {
|
||||
msg, err := event.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
op, ok := msg["operation"]
|
||||
|
||||
if !ok {
|
||||
return errNoOperationKey
|
||||
}
|
||||
switch op {
|
||||
case create:
|
||||
return es.createGroupHandler(ctx, msg)
|
||||
case update:
|
||||
return es.updateGroupHandler(ctx, msg)
|
||||
case changeStatus:
|
||||
return es.changeStatusGroupHandler(ctx, msg)
|
||||
case remove:
|
||||
return es.removeGroupHandler(ctx, msg)
|
||||
case addParentGroup:
|
||||
return es.addParentGroupHandler(ctx, msg)
|
||||
case removeParentGroup:
|
||||
return es.removeParentGroupHandler(ctx, msg)
|
||||
case addChildrenGroups:
|
||||
return es.addChildrenGroupsHandler(ctx, msg)
|
||||
case removeChildrenGroups:
|
||||
return es.removeChildrenGroupsHandler(ctx, msg)
|
||||
case removeAllChildrenGroups:
|
||||
return es.removeAllChildrenGroupsHandler(ctx, msg)
|
||||
case addRole:
|
||||
return es.rolesEventHandler.AddEntityRoleHandler(ctx, msg)
|
||||
case updateRole:
|
||||
return es.rolesEventHandler.UpdateEntityRoleHandler(ctx, msg)
|
||||
case removeRole:
|
||||
return es.rolesEventHandler.RemoveEntityRoleHandler(ctx, msg)
|
||||
case addRoleActions:
|
||||
return es.rolesEventHandler.AddEntityRoleActionsHandler(ctx, msg)
|
||||
case removeRoleActions:
|
||||
return es.rolesEventHandler.RemoveEntityRoleActionsHandler(ctx, msg)
|
||||
case removeAllRoleActions:
|
||||
return es.rolesEventHandler.RemoveAllEntityRoleActionsHandler(ctx, msg)
|
||||
case addRoleMembers:
|
||||
return es.rolesEventHandler.AddEntityRoleMembersHandler(ctx, msg)
|
||||
case removeRoleMembers:
|
||||
return es.rolesEventHandler.RemoveEntityRoleMembersHandler(ctx, msg)
|
||||
case removeRoleAllMembers:
|
||||
return es.rolesEventHandler.RemoveAllEntityRoleMembersHandler(ctx, msg)
|
||||
case removeMemberFromAllRoles:
|
||||
return es.rolesEventHandler.RemoveMemberFromAllEntityHandler(ctx, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventHandler) createGroupHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
g, rps, err := decodeCreateGroupEvent(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errCreateGroupEvent, err)
|
||||
}
|
||||
|
||||
if _, err := es.repo.Save(ctx, g); err != nil {
|
||||
return errors.Wrap(errCreateGroupEvent, err)
|
||||
}
|
||||
if _, err := es.repo.AddRoles(ctx, rps); err != nil {
|
||||
return errors.Wrap(errCreateGroupEvent, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventHandler) updateGroupHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
g, err := decodeUpdateGroupEvent(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errUpdateGroupEvent, err)
|
||||
}
|
||||
|
||||
if _, err := es.repo.Update(ctx, g); err != nil {
|
||||
return errors.Wrap(errUpdateGroupEvent, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventHandler) changeStatusGroupHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
g, err := decodeChangeStatusGroupEvent(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errChangeStatusGroupEvent, err)
|
||||
}
|
||||
|
||||
if _, err := es.repo.ChangeStatus(ctx, g); err != nil {
|
||||
return errors.Wrap(errChangeStatusGroupEvent, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventHandler) removeGroupHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
g, err := decodeRemoveGroupEvent(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRemoveGroupEvent, err)
|
||||
}
|
||||
|
||||
if err := es.repo.Delete(ctx, g.ID); err != nil {
|
||||
return errors.Wrap(errRemoveGroupEvent, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventHandler) addParentGroupHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, parent, err := decodeAddParentGroupEvent(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAddParentGroupEvent, err)
|
||||
}
|
||||
if err := es.repo.AssignParentGroup(ctx, parent, id); err != nil {
|
||||
return errors.Wrap(errAddParentGroupEvent, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventHandler) removeParentGroupHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, err := decodeRemoveParentGroupEvent(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRemoveParentGroupEvent, err)
|
||||
}
|
||||
g, err := es.repo.RetrieveByID(ctx, id)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRemoveParentGroupEvent, err)
|
||||
}
|
||||
fmt.Println(g, g.Parent, g.ID)
|
||||
if err := es.repo.UnassignParentGroup(ctx, g.Parent, id); err != nil {
|
||||
return errors.Wrap(errRemoveParentGroupEvent, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventHandler) addChildrenGroupsHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, cids, err := decodeAddChildrenGroupEvent(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errAddChildrenGroupEvent, err)
|
||||
}
|
||||
|
||||
if err := es.repo.AssignParentGroup(ctx, id, cids...); err != nil {
|
||||
return errors.Wrap(errAddChildrenGroupEvent, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventHandler) removeChildrenGroupsHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, cids, err := decodeRemoveChildrenGroupEvent(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRemoveChildrenGroupEvent, err)
|
||||
}
|
||||
|
||||
if err := es.repo.UnassignParentGroup(ctx, id, cids...); err != nil {
|
||||
return errors.Wrap(errRemoveChildrenGroupEvent, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventHandler) removeAllChildrenGroupsHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, err := decodeRemoveAllChildrenGroupEvent(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRemoveAllChildrenGroupEvent, err)
|
||||
}
|
||||
if err := es.repo.UnassignAllChildrenGroups(ctx, id); err != nil && err != repoerr.ErrNotFound {
|
||||
return errors.Wrap(errRemoveAllChildrenGroupEvent, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package events provides the events sourcing of groups to
|
||||
// provide listing in clients and channels concept definitions needed to support
|
||||
package events
|
||||
@@ -94,7 +94,7 @@ func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig)
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
|
||||
nh := ps.natsHandler(cfg.Handler)
|
||||
nh := ps.natsHandler(cfg.Handler, cfg.AckErr)
|
||||
|
||||
consumerConfig := jetstream.ConsumerConfig{
|
||||
Name: formatConsumerName(cfg.Topic, cfg.ID),
|
||||
@@ -104,6 +104,10 @@ func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig)
|
||||
FilterSubject: cfg.Topic,
|
||||
}
|
||||
|
||||
if cfg.Ordered {
|
||||
consumerConfig.MaxAckPending = 1
|
||||
}
|
||||
|
||||
switch cfg.DeliveryPolicy {
|
||||
case messaging.DeliverNewPolicy:
|
||||
consumerConfig.DeliverPolicy = jetstream.DeliverNewPolicy
|
||||
@@ -140,17 +144,22 @@ func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *pubsub) natsHandler(h messaging.MessageHandler) func(m jetstream.Msg) {
|
||||
func (ps *pubsub) natsHandler(h messaging.MessageHandler, ackErr bool) func(m jetstream.Msg) {
|
||||
return func(m jetstream.Msg) {
|
||||
var msg messaging.Message
|
||||
if err := proto.Unmarshal(m.Data(), &msg); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.Handle(&msg); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to handle SuperMQ message: %s", err))
|
||||
if ackErr {
|
||||
if err := m.Ack(); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to ack message: %s", err))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := m.Ack(); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to ack message: %s", err))
|
||||
|
||||
@@ -39,6 +39,8 @@ type SubscriberConfig struct {
|
||||
Topic string
|
||||
Handler MessageHandler
|
||||
DeliveryPolicy DeliveryPolicy
|
||||
Ordered bool
|
||||
AckErr bool
|
||||
}
|
||||
|
||||
// Subscriber specifies message subscription API.
|
||||
|
||||
@@ -29,24 +29,24 @@ func Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName string)
|
||||
updated_at TIMESTAMP,
|
||||
updated_by VARCHAR(254),
|
||||
created_by VARCHAR(254),
|
||||
CONSTRAINT unique_role_name_entity_id_constraint UNIQUE ( name, entity_id),
|
||||
CONSTRAINT fk_entity_id FOREIGN KEY(entity_id) REFERENCES %s(%s) ON DELETE CASCADE
|
||||
);`, rolesTableNamePrefix, entityTableName, entityIDColumnName),
|
||||
CONSTRAINT %s_roles_unique_role_name_entity_id_constraint UNIQUE ( name, entity_id),
|
||||
CONSTRAINT %s_roles_fk_entity_id FOREIGN KEY(entity_id) REFERENCES %s(%s) ON DELETE CASCADE
|
||||
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix, entityTableName, entityIDColumnName),
|
||||
|
||||
fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s_role_actions (
|
||||
role_id VARCHAR(254) NOT NULL,
|
||||
action VARCHAR(254) NOT NULL,
|
||||
CONSTRAINT unique_domain_role_action_constraint UNIQUE ( role_id, action),
|
||||
CONSTRAINT fk_%s_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE
|
||||
CONSTRAINT %s_role_actions_unique_domain_role_action_constraint UNIQUE ( role_id, action),
|
||||
CONSTRAINT %s_role_actions_fk_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE
|
||||
|
||||
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix),
|
||||
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix),
|
||||
|
||||
fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s_role_members (
|
||||
role_id VARCHAR(254) NOT NULL,
|
||||
member_id VARCHAR(254) NOT NULL,
|
||||
CONSTRAINT unique_role_member_constraint UNIQUE (role_id, member_id),
|
||||
CONSTRAINT fk_%s_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE
|
||||
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix),
|
||||
CONSTRAINT %s_role_members_unique_role_member_constraint UNIQUE (role_id, member_id),
|
||||
CONSTRAINT %s_role_members_fk_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE
|
||||
);`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix),
|
||||
},
|
||||
Down: []string{
|
||||
fmt.Sprintf(`DROP TABLE IF EXISTS %s_roles`, rolesTableNamePrefix),
|
||||
|
||||
@@ -406,7 +406,7 @@ func (repo *Repository) RoleAddActions(ctx context.Context, role roles.Role, act
|
||||
return []string{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
return repo.RoleListActions(ctx, role.ID)
|
||||
return actions, nil
|
||||
}
|
||||
|
||||
func (repo *Repository) RoleListActions(ctx context.Context, roleID string) ([]string, error) {
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package consumer
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
)
|
||||
|
||||
var (
|
||||
errID = errors.New("missing or invalid 'id'")
|
||||
errRoleID = errors.New("missing or invalid 'role_id'")
|
||||
errName = errors.New("missing or invalid 'name'")
|
||||
errEntityID = errors.New("missing or invalid 'entity_id'")
|
||||
errActions = errors.New("missing or invalid 'actions'")
|
||||
errMembers = errors.New("missing or invalid 'members'")
|
||||
errCreatedAt = errors.New("failed to parse 'created_at' time")
|
||||
errUpdatedAt = errors.New("failed to parse 'updated_at' time")
|
||||
errNotString = errors.New("not string type")
|
||||
|
||||
errInvalidRoleProvision = errors.New("invalid 'role_provisions'")
|
||||
errRoleProvision = errors.New("failed to convert role_provisions interface'")
|
||||
errRoleProvisionMembers = errors.New("failed to convert role_provisions member interface'")
|
||||
errRoleProvisionActions = errors.New("failed to convert role_provisions action interface'")
|
||||
)
|
||||
|
||||
const (
|
||||
layout = "2006-01-02T15:04:05.999999Z"
|
||||
)
|
||||
|
||||
func ToRole(data map[string]interface{}) (roles.Role, error) {
|
||||
var r roles.Role
|
||||
|
||||
id, ok := data["id"].(string)
|
||||
if !ok {
|
||||
return roles.Role{}, errID
|
||||
}
|
||||
r.ID = id
|
||||
|
||||
name, ok := data["name"].(string)
|
||||
if !ok {
|
||||
return roles.Role{}, errName
|
||||
}
|
||||
r.Name = name
|
||||
|
||||
eid, ok := data["entity_id"].(string)
|
||||
if !ok {
|
||||
return roles.Role{}, errEntityID
|
||||
}
|
||||
r.EntityID = eid
|
||||
|
||||
// Following fields of groups are allowed to be empty.
|
||||
|
||||
cat, ok := data["created_at"].(string)
|
||||
if ok {
|
||||
ct, err := time.Parse(layout, cat)
|
||||
if err != nil {
|
||||
return roles.Role{}, errors.Wrap(errCreatedAt, err)
|
||||
}
|
||||
r.CreatedAt = ct
|
||||
}
|
||||
|
||||
cby, ok := data["created_by"].(string)
|
||||
if ok {
|
||||
r.CreatedBy = cby
|
||||
}
|
||||
|
||||
uat, ok := data["updated_at"].(string)
|
||||
if ok {
|
||||
ut, err := time.Parse(layout, uat)
|
||||
if err != nil {
|
||||
return roles.Role{}, errors.Wrap(errUpdatedAt, err)
|
||||
}
|
||||
r.UpdatedAt = ut
|
||||
}
|
||||
|
||||
uby, ok := data["updated_by"].(string)
|
||||
if ok {
|
||||
r.UpdatedBy = uby
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func ToStrings(data []interface{}) ([]string, error) {
|
||||
var strs []string
|
||||
for _, i := range data {
|
||||
str, ok := i.(string)
|
||||
if !ok {
|
||||
return []string{}, errNotString
|
||||
}
|
||||
strs = append(strs, str)
|
||||
}
|
||||
return strs, nil
|
||||
}
|
||||
|
||||
func ToRoleProvision(data map[string]interface{}) (roles.RoleProvision, error) {
|
||||
var rp roles.RoleProvision
|
||||
|
||||
r, err := ToRole(data)
|
||||
if err != nil {
|
||||
return roles.RoleProvision{}, err
|
||||
}
|
||||
rp.Role = r
|
||||
|
||||
// Following fields of groups are allowed to be empty.
|
||||
|
||||
opActs, ok := data["optional_actions"].([]interface{})
|
||||
if ok {
|
||||
a, err := ToStrings(opActs)
|
||||
if err != nil {
|
||||
return roles.RoleProvision{}, errors.Wrap(errRoleProvisionActions, err)
|
||||
}
|
||||
rp.OptionalActions = a
|
||||
}
|
||||
|
||||
opMems, ok := data["optional_members"].([]interface{})
|
||||
if ok {
|
||||
m, err := ToStrings(opMems)
|
||||
if err != nil {
|
||||
return roles.RoleProvision{}, errors.Wrap(errRoleProvisionMembers, err)
|
||||
}
|
||||
rp.OptionalMembers = m
|
||||
}
|
||||
|
||||
return rp, nil
|
||||
}
|
||||
|
||||
func ToRoleProvisions(data []interface{}) ([]roles.RoleProvision, error) {
|
||||
var rps []roles.RoleProvision
|
||||
for _, d := range data {
|
||||
irp, ok := d.(map[string]interface{})
|
||||
if !ok {
|
||||
return []roles.RoleProvision{}, errInvalidRoleProvision
|
||||
}
|
||||
rp, err := ToRoleProvision(irp)
|
||||
if err != nil {
|
||||
return []roles.RoleProvision{}, errors.Wrap(errRoleProvision, err)
|
||||
}
|
||||
rps = append(rps, rp)
|
||||
}
|
||||
return rps, nil
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package consumer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
)
|
||||
|
||||
const (
|
||||
errAddEntityRoleEvent = "failed to consume %s add role event : %w"
|
||||
errUpdateEntityRoleEvent = "failed to consume %s update role event : %w"
|
||||
errRemoveEntityRoleEvent = "failed to consume %s remove role event : %w"
|
||||
errAddEntityRoleActionsEvent = "failed to consume %s add role actions event : %w"
|
||||
errRemoveEntityRoleActionsEvent = "failed to consume %s remove role actions event : %w"
|
||||
errRemoveEntityRoleAllActionsEvent = "failed to consume %s remove role all actions event : %w"
|
||||
errAddEntityRoleMembersEvent = "failed to consume %s add role members event : %w"
|
||||
errRemoveEntityRoleMembersEvent = "failed to consume %s remove role members event : %w"
|
||||
errRemoveEntityRoleAllMembersEvent = "failed to consume %s remove role all members event : %w"
|
||||
)
|
||||
|
||||
type EventHandler struct {
|
||||
entityType string
|
||||
repo roles.Repository
|
||||
}
|
||||
|
||||
func NewEventHandler(entityType string, repo roles.Repository) EventHandler {
|
||||
return EventHandler{
|
||||
entityType: entityType,
|
||||
repo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
func (es *EventHandler) AddEntityRoleHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
rps, err := ToRoleProvision(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errAddEntityRoleEvent, es.entityType, err)
|
||||
}
|
||||
if _, err := es.repo.AddRoles(ctx, []roles.RoleProvision{rps}); err != nil {
|
||||
if !errors.Contains(err, repoerr.ErrConflict) {
|
||||
return fmt.Errorf(errAddEntityRoleEvent, es.entityType, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *EventHandler) UpdateEntityRoleHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
ro, err := ToRole(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errUpdateEntityRoleEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
if _, err = es.repo.UpdateRole(ctx, ro); err != nil {
|
||||
return fmt.Errorf(errUpdateEntityRoleEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *EventHandler) RemoveEntityRoleHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, ok := data["role_id"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf(errRemoveEntityRoleEvent, es.entityType, errRoleID)
|
||||
}
|
||||
|
||||
if err := es.repo.RemoveRoles(ctx, []string{id}); err != nil {
|
||||
return fmt.Errorf(errRemoveEntityRoleEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *EventHandler) AddEntityRoleActionsHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, ok := data["role_id"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errRoleID)
|
||||
}
|
||||
iacts, ok := data["actions"].([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errActions)
|
||||
}
|
||||
acts, err := ToStrings(iacts)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
if _, err := es.repo.RoleAddActions(ctx, roles.Role{ID: id}, acts); err != nil {
|
||||
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *EventHandler) RemoveEntityRoleActionsHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, ok := data["role_id"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errRoleID)
|
||||
}
|
||||
iacts, ok := data["actions"].([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errActions)
|
||||
}
|
||||
acts, err := ToStrings(iacts)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
if err := es.repo.RoleRemoveActions(ctx, roles.Role{ID: id}, acts); err != nil {
|
||||
return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *EventHandler) RemoveAllEntityRoleActionsHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, ok := data["role_id"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf(errRemoveEntityRoleAllActionsEvent, es.entityType, errRoleID)
|
||||
}
|
||||
|
||||
if err := es.repo.RoleRemoveAllActions(ctx, roles.Role{ID: id}); err != nil {
|
||||
return fmt.Errorf(errRemoveEntityRoleAllActionsEvent, es.entityType, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *EventHandler) AddEntityRoleMembersHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, ok := data["role_id"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, errRoleID)
|
||||
}
|
||||
imems, ok := data["members"].([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, errMembers)
|
||||
}
|
||||
mems, err := ToStrings(imems)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
if _, err := es.repo.RoleAddMembers(ctx, roles.Role{ID: id}, mems); err != nil {
|
||||
return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *EventHandler) RemoveEntityRoleMembersHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, ok := data["role_id"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, errRoleID)
|
||||
}
|
||||
imems, ok := data["members"].([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, errMembers)
|
||||
}
|
||||
mems, err := ToStrings(imems)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
if err := es.repo.RoleRemoveMembers(ctx, roles.Role{ID: id}, mems); err != nil {
|
||||
return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *EventHandler) RemoveAllEntityRoleMembersHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
id, ok := data["role_id"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf(errRemoveEntityRoleAllMembersEvent, es.entityType, errRoleID)
|
||||
}
|
||||
|
||||
if err := es.repo.RoleRemoveAllMembers(ctx, roles.Role{ID: id}); err != nil {
|
||||
return fmt.Errorf(errRemoveEntityRoleAllMembersEvent, es.entityType, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *EventHandler) RemoveMemberFromAllEntityHandler(ctx context.Context, data map[string]interface{}) error {
|
||||
return nil
|
||||
}
|
||||
@@ -24,9 +24,10 @@ type RoleManagerEventStore struct {
|
||||
// events to event store.
|
||||
func NewRoleManagerEventStore(svcName, operationPrefix string, svc roles.RoleManager, publisher events.Publisher) RoleManagerEventStore {
|
||||
return RoleManagerEventStore{
|
||||
svcName: svcName,
|
||||
svc: svc,
|
||||
Publisher: publisher,
|
||||
svcName: svcName,
|
||||
operationPrefix: operationPrefix,
|
||||
svc: svc,
|
||||
Publisher: publisher,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+68
-43
@@ -392,9 +392,11 @@ func TestListChannels(t *testing.T) {
|
||||
offset: offset,
|
||||
total: total,
|
||||
channelsPageMeta: channels.PageMetadata{
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
Permission: defPermission,
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
},
|
||||
svcRes: channels.Page{
|
||||
PageMetadata: channels.PageMetadata{
|
||||
@@ -417,8 +419,11 @@ func TestListChannels(t *testing.T) {
|
||||
offset: offset,
|
||||
limit: limit,
|
||||
channelsPageMeta: channels.PageMetadata{
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
},
|
||||
svcRes: channels.Page{},
|
||||
authenticateErr: svcerr.ErrAuthentication,
|
||||
@@ -426,16 +431,20 @@ func TestListChannels(t *testing.T) {
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "list channels with empty token",
|
||||
token: "",
|
||||
domainID: validID,
|
||||
offset: offset,
|
||||
limit: limit,
|
||||
channelsPageMeta: channels.PageMetadata{},
|
||||
svcRes: channels.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.ChannelsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
|
||||
desc: "list channels with empty token",
|
||||
token: "",
|
||||
domainID: validID,
|
||||
offset: offset,
|
||||
limit: limit,
|
||||
channelsPageMeta: channels.PageMetadata{
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
},
|
||||
svcRes: channels.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.ChannelsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "list channels with zero limit",
|
||||
@@ -444,9 +453,11 @@ func TestListChannels(t *testing.T) {
|
||||
offset: offset,
|
||||
limit: 0,
|
||||
channelsPageMeta: channels.PageMetadata{
|
||||
Offset: offset,
|
||||
Limit: 10,
|
||||
Permission: defPermission,
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: offset,
|
||||
Limit: 10,
|
||||
},
|
||||
svcRes: channels.Page{
|
||||
PageMetadata: channels.PageMetadata{
|
||||
@@ -464,16 +475,20 @@ func TestListChannels(t *testing.T) {
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list channels with limit greater than max",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
offset: offset,
|
||||
limit: 110,
|
||||
channelsPageMeta: channels.PageMetadata{},
|
||||
svcRes: channels.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.ChannelsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
desc: "list channels with limit greater than max",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
offset: offset,
|
||||
limit: 110,
|
||||
channelsPageMeta: channels.PageMetadata{
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
},
|
||||
svcRes: channels.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.ChannelsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list channels with level",
|
||||
@@ -483,9 +498,11 @@ func TestListChannels(t *testing.T) {
|
||||
limit: 1,
|
||||
level: 1,
|
||||
channelsPageMeta: channels.PageMetadata{
|
||||
Offset: offset,
|
||||
Limit: 1,
|
||||
Permission: defPermission,
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: offset,
|
||||
Limit: 1,
|
||||
},
|
||||
svcRes: channels.Page{
|
||||
PageMetadata: channels.PageMetadata{
|
||||
@@ -510,10 +527,12 @@ func TestListChannels(t *testing.T) {
|
||||
limit: 10,
|
||||
metadata: sdk.Metadata{"name": "client_89"},
|
||||
channelsPageMeta: channels.PageMetadata{
|
||||
Offset: offset,
|
||||
Limit: 10,
|
||||
Permission: defPermission,
|
||||
Metadata: clients.Metadata{"name": "client_89"},
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: offset,
|
||||
Limit: 10,
|
||||
Metadata: clients.Metadata{"name": "client_89"},
|
||||
},
|
||||
svcRes: channels.Page{
|
||||
PageMetadata: channels.PageMetadata{
|
||||
@@ -539,11 +558,15 @@ func TestListChannels(t *testing.T) {
|
||||
metadata: sdk.Metadata{
|
||||
"test": make(chan int),
|
||||
},
|
||||
channelsPageMeta: channels.PageMetadata{},
|
||||
svcRes: channels.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.ChannelsPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
channelsPageMeta: channels.PageMetadata{
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
},
|
||||
svcRes: channels.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.ChannelsPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "list channels with service response that can't be unmarshalled",
|
||||
@@ -552,9 +575,11 @@ func TestListChannels(t *testing.T) {
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
channelsPageMeta: channels.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Permission: defPermission,
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
svcRes: channels.Page{
|
||||
PageMetadata: channels.PageMetadata{
|
||||
|
||||
@@ -252,23 +252,6 @@ func (sdk mgSDK) DeleteClient(id, domainID, token string) errors.SDKError {
|
||||
return sdkerr
|
||||
}
|
||||
|
||||
func (sdk mgSDK) ListUserClients(userID, domainID string, pm PageMetadata, token string) (ClientsPage, errors.SDKError) {
|
||||
url, err := sdk.withQueryParams(sdk.clientsURL, fmt.Sprintf("%s/%s/%s/%s", domainID, usersEndpoint, userID, clientsEndpoint), pm)
|
||||
if err != nil {
|
||||
return ClientsPage{}, errors.NewSDKError(err)
|
||||
}
|
||||
_, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK)
|
||||
if sdkerr != nil {
|
||||
return ClientsPage{}, sdkerr
|
||||
}
|
||||
cp := ClientsPage{}
|
||||
if err := json.Unmarshal(body, &cp); err != nil {
|
||||
return ClientsPage{}, errors.NewSDKError(err)
|
||||
}
|
||||
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
func (sdk mgSDK) CreateClientRole(id, domainID string, rq RoleReq, token string) (Role, errors.SDKError) {
|
||||
return sdk.createRole(sdk.clientsURL, clientsEndpoint, id, domainID, rq, token)
|
||||
}
|
||||
|
||||
+44
-275
@@ -357,9 +357,11 @@ func TestListClients(t *testing.T) {
|
||||
Limit: 100,
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
svcRes: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
@@ -387,9 +389,11 @@ func TestListClients(t *testing.T) {
|
||||
Limit: 100,
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
svcRes: clients.ClientsPage{},
|
||||
authenticateErr: svcerr.ErrAuthentication,
|
||||
@@ -404,7 +408,11 @@ func TestListClients(t *testing.T) {
|
||||
Offset: 0,
|
||||
Limit: 1000,
|
||||
},
|
||||
svcReq: clients.Page{},
|
||||
svcReq: clients.Page{
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
},
|
||||
svcRes: clients.ClientsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
@@ -419,7 +427,11 @@ func TestListClients(t *testing.T) {
|
||||
Limit: 100,
|
||||
Name: strings.Repeat("a", 1025),
|
||||
},
|
||||
svcReq: clients.Page{},
|
||||
svcReq: clients.Page{
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
},
|
||||
svcRes: clients.ClientsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
@@ -435,10 +447,12 @@ func TestListClients(t *testing.T) {
|
||||
Status: clients.DisabledStatus.String(),
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
Status: clients.DisabledStatus,
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Status: clients.DisabledStatus,
|
||||
},
|
||||
svcRes: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
@@ -468,10 +482,12 @@ func TestListClients(t *testing.T) {
|
||||
Tag: "tag1",
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
Tag: "tag1",
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Tag: "tag1",
|
||||
},
|
||||
svcRes: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
@@ -502,7 +518,11 @@ func TestListClients(t *testing.T) {
|
||||
"test": make(chan int),
|
||||
},
|
||||
},
|
||||
svcReq: clients.Page{},
|
||||
svcReq: clients.Page{
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
},
|
||||
svcRes: clients.ClientsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
@@ -517,9 +537,11 @@ func TestListClients(t *testing.T) {
|
||||
Limit: 100,
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
Actions: []string{},
|
||||
Order: "updated_at",
|
||||
Dir: "asc",
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
svcRes: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
@@ -547,12 +569,12 @@ func TestListClients(t *testing.T) {
|
||||
tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
|
||||
}
|
||||
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
|
||||
svcCall := tsvc.On("ListClients", mock.Anything, tc.session, mock.Anything, tc.svcReq).Return(tc.svcRes, tc.svcErr)
|
||||
svcCall := tsvc.On("ListClients", mock.Anything, tc.session, tc.svcReq).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := mgsdk.Clients(tc.pageMeta, tc.domainID, tc.token)
|
||||
assert.Equal(t, tc.err, err)
|
||||
assert.Equal(t, tc.response, resp)
|
||||
if tc.err == nil {
|
||||
ok := svcCall.Parent.AssertCalled(t, "ListClients", mock.Anything, tc.session, mock.Anything, tc.svcReq)
|
||||
ok := svcCall.Parent.AssertCalled(t, "ListClients", mock.Anything, tc.session, tc.svcReq)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
svcCall.Unset()
|
||||
@@ -1400,259 +1422,6 @@ func TestDeleteClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserClients(t *testing.T) {
|
||||
ts, tsvc, auth := setupClients()
|
||||
defer ts.Close()
|
||||
|
||||
var sdkClients []sdk.Client
|
||||
for i := 10; i < 100; i++ {
|
||||
c := generateTestClient(t)
|
||||
if i == 50 {
|
||||
c.Status = clients.DisabledStatus.String()
|
||||
c.Tags = []string{"tag1", "tag2"}
|
||||
}
|
||||
sdkClients = append(sdkClients, c)
|
||||
}
|
||||
|
||||
conf := sdk.Config{
|
||||
ClientsURL: ts.URL,
|
||||
}
|
||||
mgsdk := sdk.NewSDK(conf)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
session smqauthn.Session
|
||||
userID string
|
||||
domainID string
|
||||
pageMeta sdk.PageMetadata
|
||||
svcReq clients.Page
|
||||
svcRes clients.ClientsPage
|
||||
svcErr error
|
||||
authenticateErr error
|
||||
response sdk.ClientsPage
|
||||
err errors.SDKError
|
||||
}{
|
||||
{
|
||||
desc: "list user clients successfully",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
pageMeta: sdk.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
},
|
||||
svcRes: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Total: uint64(len(sdkClients)),
|
||||
},
|
||||
Clients: convertClients(sdkClients...),
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{
|
||||
PageRes: sdk.PageRes{
|
||||
Limit: 100,
|
||||
Total: uint64(len(sdkClients)),
|
||||
},
|
||||
Clients: sdkClients,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "list user clients with an invalid token",
|
||||
token: invalidToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
pageMeta: sdk.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
},
|
||||
svcRes: clients.ClientsPage{},
|
||||
authenticateErr: svcerr.ErrAuthentication,
|
||||
response: sdk.ClientsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "list user clients with limit greater than max",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
pageMeta: sdk.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 1000,
|
||||
},
|
||||
svcReq: clients.Page{},
|
||||
svcRes: clients.ClientsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list user clients with name size greater than max",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
pageMeta: sdk.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Name: strings.Repeat("a", 1025),
|
||||
},
|
||||
svcReq: clients.Page{},
|
||||
svcRes: clients.ClientsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list user clients with status",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
pageMeta: sdk.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Status: clients.DisabledStatus.String(),
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
Status: clients.DisabledStatus,
|
||||
},
|
||||
svcRes: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Total: 1,
|
||||
},
|
||||
Clients: convertClients(sdkClients[50]),
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{
|
||||
PageRes: sdk.PageRes{
|
||||
Limit: 100,
|
||||
Total: 1,
|
||||
},
|
||||
Clients: []sdk.Client{sdkClients[50]},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user clients with tags",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
pageMeta: sdk.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Tag: "tag1",
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
Tag: "tag1",
|
||||
},
|
||||
svcRes: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Total: 1,
|
||||
},
|
||||
Clients: convertClients(sdkClients[50]),
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{
|
||||
PageRes: sdk.PageRes{
|
||||
Limit: 100,
|
||||
Total: 1,
|
||||
},
|
||||
Clients: []sdk.Client{sdkClients[50]},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user clients with invalid metadata",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
pageMeta: sdk.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Metadata: map[string]interface{}{
|
||||
"test": make(chan int),
|
||||
},
|
||||
},
|
||||
svcReq: clients.Page{},
|
||||
svcRes: clients.ClientsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "list user clients with response that can't be unmarshalled",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
pageMeta: sdk.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
svcReq: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Permission: defPermission,
|
||||
},
|
||||
svcRes: clients.ClientsPage{
|
||||
Page: clients.Page{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Total: 1,
|
||||
},
|
||||
Clients: []clients.Client{{
|
||||
Name: sdkClients[0].Name,
|
||||
Tags: sdkClients[0].Tags,
|
||||
Credentials: clients.Credentials(sdkClients[0].Credentials),
|
||||
Metadata: clients.Metadata{
|
||||
"test": make(chan int),
|
||||
},
|
||||
}},
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
if tc.token == validToken {
|
||||
tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
|
||||
}
|
||||
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
|
||||
svcCall := tsvc.On("ListClients", mock.Anything, tc.session, tc.userID, tc.svcReq).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := mgsdk.ListUserClients(tc.userID, tc.domainID, tc.pageMeta, tc.token)
|
||||
assert.Equal(t, tc.err, err)
|
||||
assert.Equal(t, tc.response, resp)
|
||||
if tc.err == nil {
|
||||
ok := svcCall.Parent.AssertCalled(t, "ListClients", mock.Anything, tc.session, tc.userID, tc.svcReq)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
svcCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetClientParent(t *testing.T) {
|
||||
ts, csvc, auth := setupClients()
|
||||
defer ts.Close()
|
||||
|
||||
@@ -4393,67 +4393,6 @@ func (_c *SDK_ListDomainUsers_Call) RunAndReturn(run func(string, sdk.PageMetada
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListUserClients provides a mock function with given fields: userID, domainID, pm, token
|
||||
func (_m *SDK) ListUserClients(userID string, domainID string, pm sdk.PageMetadata, token string) (sdk.ClientsPage, errors.SDKError) {
|
||||
ret := _m.Called(userID, domainID, pm, token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListUserClients")
|
||||
}
|
||||
|
||||
var r0 sdk.ClientsPage
|
||||
var r1 errors.SDKError
|
||||
if rf, ok := ret.Get(0).(func(string, string, sdk.PageMetadata, string) (sdk.ClientsPage, errors.SDKError)); ok {
|
||||
return rf(userID, domainID, pm, token)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string, string, sdk.PageMetadata, string) sdk.ClientsPage); ok {
|
||||
r0 = rf(userID, domainID, pm, token)
|
||||
} else {
|
||||
r0 = ret.Get(0).(sdk.ClientsPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string, string, sdk.PageMetadata, string) errors.SDKError); ok {
|
||||
r1 = rf(userID, domainID, pm, token)
|
||||
} else {
|
||||
if ret.Get(1) != nil {
|
||||
r1 = ret.Get(1).(errors.SDKError)
|
||||
}
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// SDK_ListUserClients_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserClients'
|
||||
type SDK_ListUserClients_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListUserClients is a helper method to define mock.On call
|
||||
// - userID string
|
||||
// - domainID string
|
||||
// - pm sdk.PageMetadata
|
||||
// - token string
|
||||
func (_e *SDK_Expecter) ListUserClients(userID interface{}, domainID interface{}, pm interface{}, token interface{}) *SDK_ListUserClients_Call {
|
||||
return &SDK_ListUserClients_Call{Call: _e.mock.On("ListUserClients", userID, domainID, pm, token)}
|
||||
}
|
||||
|
||||
func (_c *SDK_ListUserClients_Call) Run(run func(userID string, domainID string, pm sdk.PageMetadata, token string)) *SDK_ListUserClients_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string), args[1].(string), args[2].(sdk.PageMetadata), args[3].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *SDK_ListUserClients_Call) Return(_a0 sdk.ClientsPage, _a1 errors.SDKError) *SDK_ListUserClients_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *SDK_ListUserClients_Call) RunAndReturn(run func(string, string, sdk.PageMetadata, string) (sdk.ClientsPage, errors.SDKError)) *SDK_ListUserClients_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Members provides a mock function with given fields: groupID, domainID, pm, token
|
||||
func (_m *SDK) Members(groupID string, domainID string, pm sdk.PageMetadata, token string) (sdk.UsersPage, errors.SDKError) {
|
||||
ret := _m.Called(groupID, domainID, pm, token)
|
||||
|
||||
@@ -515,17 +515,6 @@ type SDK interface {
|
||||
// fmt.Println(err)
|
||||
RemoveClientParent(id, domainID, groupID, token string) errors.SDKError
|
||||
|
||||
// ListUserClients returns list of clients for the given user ID and filters.
|
||||
//
|
||||
// example:
|
||||
// pm := sdk.PageMetadata{
|
||||
// Offset: 0,
|
||||
// Limit: 10,
|
||||
// }
|
||||
// clients, _ := sdk.ListUserClients("userID", "domainID", pm,"token")
|
||||
// fmt.Println(clients)
|
||||
ListUserClients(userID, domainID string, pm PageMetadata, token string) (ClientsPage, errors.SDKError)
|
||||
|
||||
// CreateClientRole creates new client role and returns its id.
|
||||
//
|
||||
// example:
|
||||
|
||||
@@ -216,7 +216,6 @@ func convertChannel(g sdk.Channel) mgchannels.Channel {
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
UpdatedBy: g.UpdatedBy,
|
||||
Status: status,
|
||||
Permissions: g.Permissions,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user