NOISSUE - Update changes to domains and channels (#2752)

Signed-off-by: ianmuchyri <ianmuchiri8@gmail.com>
This commit is contained in:
Ian Ngethe Muchiri
2025-03-12 15:31:03 +03:00
committed by GitHub
parent e7032a6313
commit 065e764387
39 changed files with 1001 additions and 550 deletions
+1 -1
View File
@@ -1,4 +1,4 @@
// Code generated by mockery v2.52.3. DO NOT EDIT.
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
+2 -2
View File
@@ -12,10 +12,10 @@ import (
grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1"
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
"github.com/absmach/supermq/channels"
ch "github.com/absmach/supermq/channels"
grpcapi "github.com/absmach/supermq/channels/api/grpc"
"github.com/absmach/supermq/channels/private/mocks"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
@@ -35,7 +35,7 @@ var (
validChannel = ch.Channel{
ID: validID,
Domain: testsutil.GenerateUUID(&testing.T{}),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
}
)
+25 -17
View File
@@ -11,7 +11,7 @@ import (
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/pkg/errors"
"github.com/go-chi/chi/v5"
)
@@ -65,7 +65,7 @@ func decodeListChannels(_ context.Context, r *http.Request) (interface{}, error)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
status, err := clients.ToStatus(s)
status, err := channels.ToStatus(s)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
@@ -135,22 +135,30 @@ func decodeListChannels(_ context.Context, r *http.Request) (interface{}, error)
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
req := listChannelsReq{
name: name,
tag: tag,
status: status,
metadata: meta,
roleName: roleName,
roleID: roleID,
actions: actions,
accessType: accessType,
order: order,
dir: dir,
offset: offset,
limit: limit,
groupID: groupID,
clientID: clientID,
userID: userID,
Page: channels.Page{
Name: name,
Tag: tag,
Status: status,
Metadata: meta,
RoleName: roleName,
RoleID: roleID,
Actions: actions,
AccessType: accessType,
Order: order,
Dir: dir,
Offset: offset,
Limit: limit,
Group: groupID,
Client: clientID,
ID: id,
},
userID: userID,
}
return req, nil
}
+28 -29
View File
@@ -17,7 +17,6 @@ import (
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/channels/mocks"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/internal/testsutil"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
@@ -39,13 +38,13 @@ var (
Name: valid,
Domain: testsutil.GenerateUUID(&testing.T{}),
ParentGroup: testsutil.GenerateUUID(&testing.T{}),
Metadata: clients.Metadata{
Metadata: channels.Metadata{
"name": "test",
},
CreatedAt: time.Now().Add(-1 * time.Second),
UpdatedAt: time.Now(),
UpdatedBy: testsutil.GenerateUUID(&testing.T{}),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
}
validID = testsutil.GenerateUUID(&testing.T{})
validToken = "validToken"
@@ -439,7 +438,7 @@ func TestListChannels(t *testing.T) {
domainID string
token string
session smqauthn.Session
listChannelsResponse channels.Page
listChannelsResponse channels.ChannelsPage
status int
authnErr error
err error
@@ -449,8 +448,8 @@ func TestListChannels(t *testing.T) {
domainID: validID,
token: validToken,
status: http.StatusOK,
listChannelsResponse: channels.Page{
PageMetadata: channels.PageMetadata{
listChannelsResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{validChannelResp},
@@ -476,8 +475,8 @@ func TestListChannels(t *testing.T) {
desc: "list channels with offset",
domainID: validID,
token: validToken,
listChannelsResponse: channels.Page{
PageMetadata: channels.PageMetadata{
listChannelsResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{validChannelResp},
@@ -498,8 +497,8 @@ func TestListChannels(t *testing.T) {
desc: "list channels with limit",
domainID: validID,
token: validToken,
listChannelsResponse: channels.Page{
PageMetadata: channels.PageMetadata{
listChannelsResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{validChannelResp},
@@ -528,8 +527,8 @@ func TestListChannels(t *testing.T) {
desc: "list channels with name",
domainID: validID,
token: validToken,
listChannelsResponse: channels.Page{
PageMetadata: channels.PageMetadata{
listChannelsResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{validChannelResp},
@@ -558,8 +557,8 @@ func TestListChannels(t *testing.T) {
desc: "list channels with status",
domainID: validID,
token: validToken,
listChannelsResponse: channels.Page{
PageMetadata: channels.PageMetadata{
listChannelsResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{validChannelResp},
@@ -588,8 +587,8 @@ func TestListChannels(t *testing.T) {
desc: "list channels with tags",
domainID: validID,
token: validToken,
listChannelsResponse: channels.Page{
PageMetadata: channels.PageMetadata{
listChannelsResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{validChannelResp},
@@ -618,8 +617,8 @@ func TestListChannels(t *testing.T) {
desc: "list channels with metadata",
domainID: validID,
token: validToken,
listChannelsResponse: channels.Page{
PageMetadata: channels.PageMetadata{
listChannelsResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{validChannelResp},
@@ -648,8 +647,8 @@ func TestListChannels(t *testing.T) {
desc: "list channels with permissions",
domainID: validID,
token: validToken,
listChannelsResponse: channels.Page{
PageMetadata: channels.PageMetadata{
listChannelsResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{validChannelResp},
@@ -678,8 +677,8 @@ func TestListChannels(t *testing.T) {
desc: "list channels with list perms",
domainID: validID,
token: validToken,
listChannelsResponse: channels.Page{
PageMetadata: channels.PageMetadata{
listChannelsResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{validChannelResp},
@@ -2032,11 +2031,11 @@ func toJSON(data interface{}) string {
}
type respBody struct {
Err string `json:"error"`
Message string `json:"message"`
Total int `json:"total"`
Permissions []string `json:"permissions"`
ID string `json:"id"`
Tags []string `json:"tags"`
Status clients.Status `json:"status"`
Err string `json:"error"`
Message string `json:"message"`
Total int `json:"total"`
Permissions []string `json:"permissions"`
ID string `json:"id"`
Tags []string `json:"tags"`
Status channels.Status `json:"status"`
}
+3 -21
View File
@@ -103,31 +103,13 @@ func listChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, svcerr.ErrAuthentication
}
pm := channels.PageMetadata{
Offset: req.offset,
Limit: req.limit,
Name: req.name,
Order: req.order,
Dir: req.dir,
Metadata: req.metadata,
Tag: req.tag,
Status: req.status,
Group: req.groupID,
Client: req.clientID,
ConnectionType: req.connType,
RoleName: req.roleName,
RoleID: req.roleID,
Actions: req.actions,
AccessType: req.accessType,
}
var page channels.Page
var page channels.ChannelsPage
var err error
switch req.userID != "" {
case true:
page, err = svc.ListUserChannels(ctx, session, req.userID, pm)
page, err = svc.ListUserChannels(ctx, session, req.userID, req.Page)
default:
page, err = svc.ListChannels(ctx, session, pm)
page, err = svc.ListChannels(ctx, session, req.Page)
}
if err != nil {
return channelsPageRes{}, err
+4 -19
View File
@@ -9,7 +9,6 @@ import (
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/connections"
)
@@ -64,30 +63,16 @@ func (req viewChannelReq) validate() error {
}
type listChannelsReq struct {
name string
tag string
status clients.Status
metadata clients.Metadata
roleName string
roleID string
actions []string
accessType string
order string
dir string
offset uint64
limit uint64
groupID string
clientID string
connType string
userID string
channels.Page
userID string
}
func (req listChannelsReq) validate() error {
if req.limit > api.MaxLimitSize || req.limit < 1 {
if req.Limit > api.MaxLimitSize || req.Limit < 1 {
return apiutil.ErrLimitSize
}
if len(req.name) > api.MaxNameSize {
if len(req.Name) > api.MaxNameSize {
return apiutil.ErrNameSize
}
+4 -5
View File
@@ -148,29 +148,28 @@ func TestListChannelsReqValidation(t *testing.T) {
{
desc: "valid request",
req: listChannelsReq{
limit: 10,
Page: channels.Page{Limit: 10},
},
err: nil,
},
{
desc: "limit is 0",
req: listChannelsReq{
limit: 0,
Page: channels.Page{Limit: 0},
},
err: apiutil.ErrLimitSize,
},
{
desc: "limit is greater than max limit",
req: listChannelsReq{
limit: api.MaxLimitSize + 1,
Page: channels.Page{Limit: api.MaxLimitSize + 1},
},
err: apiutil.ErrLimitSize,
},
{
desc: "name is too long",
req: listChannelsReq{
limit: 10,
name: strings.Repeat("a", api.MaxNameSize+1),
Page: channels.Page{Limit: 10, Name: strings.Repeat("a", api.MaxNameSize+1)},
},
err: apiutil.ErrNameSize,
},
+40 -38
View File
@@ -7,26 +7,28 @@ import (
"context"
"time"
clients "github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/roles"
)
// Metadata represents arbitrary JSON.
type Metadata map[string]interface{}
// Channel represents a SuperMQ "communication topic". This topic
// contains the clients that can exchange messages between each other.
type Channel struct {
ID string `json:"id"`
Name string `json:"name,omitempty"`
Tags []string `json:"tags,omitempty"`
ParentGroup string `json:"parent_group_id,omitempty"`
Domain string `json:"domain_id,omitempty"`
Metadata clients.Metadata `json:"metadata,omitempty"`
CreatedBy string `json:"created_by,omitempty"`
CreatedAt time.Time `json:"created_at,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
UpdatedBy string `json:"updated_by,omitempty"`
Status clients.Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
ID string `json:"id"`
Name string `json:"name,omitempty"`
Tags []string `json:"tags,omitempty"`
ParentGroup string `json:"parent_group_id,omitempty"`
Domain string `json:"domain_id,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
CreatedBy string `json:"created_by,omitempty"`
CreatedAt time.Time `json:"created_at,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
UpdatedBy string `json:"updated_by,omitempty"`
Status Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
// Extended
ParentGroupPath string `json:"parent_group_path,omitempty"`
RoleID string `json:"role_id,omitempty"`
@@ -40,32 +42,32 @@ type Channel struct {
ConnectionTypes []connections.ConnType `json:"connection_types,omitempty"`
}
type PageMetadata struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Metadata clients.Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
Tag string `json:"tag,omitempty"`
Status clients.Status `json:"status,omitempty"`
Group string `json:"group,omitempty"`
Client string `json:"client,omitempty"`
ConnectionType string `json:"connection_type,omitempty"`
RoleName string `json:"role_name,omitempty"`
RoleID string `json:"role_id,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
IDs []string `json:"-"`
type Page struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
Tag string `json:"tag,omitempty"`
Status Status `json:"status,omitempty"`
Group string `json:"group,omitempty"`
Client string `json:"client,omitempty"`
ConnectionType string `json:"connection_type,omitempty"`
RoleName string `json:"role_name,omitempty"`
RoleID string `json:"role_id,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
IDs []string `json:"-"`
}
// ChannelsPage contains page related metadata as well as list of channels that
// belong to this page.
type Page struct {
PageMetadata
type ChannelsPage struct {
Page
Channels []Channel
}
@@ -105,10 +107,10 @@ type Service interface {
DisableChannel(ctx context.Context, session authn.Session, id string) (Channel, error)
// ListChannels retrieves data about subset of channels that belongs to the user.
ListChannels(ctx context.Context, session authn.Session, pm PageMetadata) (Page, error)
ListChannels(ctx context.Context, session authn.Session, pm Page) (ChannelsPage, error)
// ListUserChannels retrieves data about subset of channels that belong to the specified user.
ListUserChannels(ctx context.Context, session authn.Session, userID string, pm PageMetadata) (Page, error)
ListUserChannels(ctx context.Context, session authn.Session, userID string, pm Page) (ChannelsPage, error)
// RemoveChannel removes the client identified by the provided ID, that
// belongs to the user.
@@ -144,13 +146,13 @@ type Repository interface {
ChangeStatus(ctx context.Context, channel Channel) (Channel, error)
// RetrieveUserChannels retrieves the channel of given domainID and userID.
RetrieveUserChannels(ctx context.Context, domainID, userID string, pm PageMetadata) (Page, error)
RetrieveUserChannels(ctx context.Context, domainID, userID string, pm Page) (ChannelsPage, error)
// RetrieveByID retrieves the channel having the provided identifier
RetrieveByID(ctx context.Context, id string) (Channel, error)
// RetrieveAll retrieves the subset of channels.
RetrieveAll(ctx context.Context, pm PageMetadata) (Page, error)
RetrieveAll(ctx context.Context, pm Page) (ChannelsPage, error)
// Remove removes the channel having the provided identifier
Remove(ctx context.Context, ids ...string) error
+17
View File
@@ -0,0 +1,17 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package channels
import "errors"
var (
// ErrInvalidStatus indicates invalid status.
ErrInvalidStatus = errors.New("invalid channels status")
// ErrEnableChannel indicates error in enabling channel.
ErrEnableChannel = errors.New("failed to enable channel")
// ErrDisableChannel indicates error in disabling channel.
ErrDisableChannel = errors.New("failed to disable channel")
)
+2 -2
View File
@@ -183,7 +183,7 @@ func (vce viewChannelEvent) Encode() (map[string]interface{}, error) {
}
type listChannelEvent struct {
channels.PageMetadata
channels.Page
authn.Session
requestID string
}
@@ -228,7 +228,7 @@ func (lce listChannelEvent) Encode() (map[string]interface{}, error) {
type listUserChannelsEvent struct {
userID string
channels.PageMetadata
channels.Page
authn.Session
requestID string
}
+9 -9
View File
@@ -114,15 +114,15 @@ func (es *eventStore) ViewChannel(ctx context.Context, session authn.Session, id
return chann, nil
}
func (es *eventStore) ListChannels(ctx context.Context, session authn.Session, pm channels.PageMetadata) (channels.Page, error) {
func (es *eventStore) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
cp, err := es.svc.ListChannels(ctx, session, pm)
if err != nil {
return cp, err
}
event := listChannelEvent{
PageMetadata: pm,
Session: session,
requestID: middleware.GetReqID(ctx),
Page: pm,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, event); err != nil {
return cp, err
@@ -131,16 +131,16 @@ func (es *eventStore) ListChannels(ctx context.Context, session authn.Session, p
return cp, nil
}
func (es *eventStore) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
func (es *eventStore) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
cp, err := es.svc.ListUserChannels(ctx, session, userID, pm)
if err != nil {
return cp, err
}
event := listUserChannelsEvent{
userID: userID,
PageMetadata: pm,
Session: session,
requestID: middleware.GetReqID(ctx),
userID: userID,
Page: pm,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, event); err != nil {
return cp, err
+5 -5
View File
@@ -148,7 +148,7 @@ func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session auth
return am.svc.ViewChannel(ctx, session, id)
}
func (am *authorizationMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.PageMetadata) (channels.Page, error) {
func (am *authorizationMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
@@ -158,7 +158,7 @@ func (am *authorizationMiddleware) ListChannels(ctx context.Context, session aut
Operation: auth.ListOp,
EntityID: auth.AnyIDs,
}); err != nil {
return channels.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
return channels.ChannelsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
}
@@ -168,7 +168,7 @@ func (am *authorizationMiddleware) ListChannels(ctx context.Context, session aut
return am.svc.ListChannels(ctx, session, pm)
}
func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
@@ -178,11 +178,11 @@ func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session
Operation: auth.ListOp,
EntityID: auth.AnyIDs,
}); err != nil {
return channels.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
return channels.ChannelsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
}
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
return channels.Page{}, errors.Wrap(err, errList)
return channels.ChannelsPage{}, errors.Wrap(err, errList)
}
return am.svc.ListUserChannels(ctx, session, userID, pm)
}
+2 -2
View File
@@ -67,7 +67,7 @@ func (lm *loggingMiddleware) ViewChannel(ctx context.Context, session authn.Sess
return lm.svc.ViewChannel(ctx, session, id)
}
func (lm *loggingMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.PageMetadata) (cp channels.Page, err error) {
func (lm *loggingMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (cp channels.ChannelsPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
@@ -89,7 +89,7 @@ func (lm *loggingMiddleware) ListChannels(ctx context.Context, session authn.Ses
return lm.svc.ListChannels(ctx, session, pm)
}
func (lm *loggingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (cp channels.Page, err error) {
func (lm *loggingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (cp channels.ChannelsPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
+2 -2
View File
@@ -50,7 +50,7 @@ func (ms *metricsMiddleware) ViewChannel(ctx context.Context, session authn.Sess
return ms.svc.ViewChannel(ctx, session, id)
}
func (ms *metricsMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.PageMetadata) (channels.Page, error) {
func (ms *metricsMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_channels").Add(1)
ms.latency.With("method", "list_channels").Observe(time.Since(begin).Seconds())
@@ -58,7 +58,7 @@ func (ms *metricsMiddleware) ListChannels(ctx context.Context, session authn.Ses
return ms.svc.ListChannels(ctx, session, pm)
}
func (ms *metricsMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
func (ms *metricsMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_user_channels").Add(1)
ms.latency.With("method", "list_user_channels").Observe(time.Since(begin).Seconds())
+12 -12
View File
@@ -367,25 +367,25 @@ func (_m *Repository) RemoveRoles(ctx context.Context, roleIDs []string) error {
}
// RetrieveAll provides a mock function with given fields: ctx, pm
func (_m *Repository) RetrieveAll(ctx context.Context, pm channels.PageMetadata) (channels.Page, error) {
func (_m *Repository) RetrieveAll(ctx context.Context, pm channels.Page) (channels.ChannelsPage, error) {
ret := _m.Called(ctx, pm)
if len(ret) == 0 {
panic("no return value specified for RetrieveAll")
}
var r0 channels.Page
var r0 channels.ChannelsPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, channels.PageMetadata) (channels.Page, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, channels.Page) (channels.ChannelsPage, error)); ok {
return rf(ctx, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, channels.PageMetadata) channels.Page); ok {
if rf, ok := ret.Get(0).(func(context.Context, channels.Page) channels.ChannelsPage); ok {
r0 = rf(ctx, pm)
} else {
r0 = ret.Get(0).(channels.Page)
r0 = ret.Get(0).(channels.ChannelsPage)
}
if rf, ok := ret.Get(1).(func(context.Context, channels.PageMetadata) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, channels.Page) error); ok {
r1 = rf(ctx, pm)
} else {
r1 = ret.Error(1)
@@ -576,25 +576,25 @@ func (_m *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Ro
}
// RetrieveUserChannels provides a mock function with given fields: ctx, domainID, userID, pm
func (_m *Repository) RetrieveUserChannels(ctx context.Context, domainID string, userID string, pm channels.PageMetadata) (channels.Page, error) {
func (_m *Repository) RetrieveUserChannels(ctx context.Context, domainID string, userID string, pm channels.Page) (channels.ChannelsPage, error) {
ret := _m.Called(ctx, domainID, userID, pm)
if len(ret) == 0 {
panic("no return value specified for RetrieveUserChannels")
}
var r0 channels.Page
var r0 channels.ChannelsPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.PageMetadata) (channels.Page, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.Page) (channels.ChannelsPage, error)); ok {
return rf(ctx, domainID, userID, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.PageMetadata) channels.Page); ok {
if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.Page) channels.ChannelsPage); ok {
r0 = rf(ctx, domainID, userID, pm)
} else {
r0 = ret.Get(0).(channels.Page)
r0 = ret.Get(0).(channels.ChannelsPage)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, channels.PageMetadata) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, string, string, channels.Page) error); ok {
r1 = rf(ctx, domainID, userID, pm)
} else {
r1 = ret.Error(1)
+12 -12
View File
@@ -219,25 +219,25 @@ func (_m *Service) ListAvailableActions(ctx context.Context, session authn.Sessi
}
// ListChannels provides a mock function with given fields: ctx, session, pm
func (_m *Service) ListChannels(ctx context.Context, session authn.Session, pm channels.PageMetadata) (channels.Page, error) {
func (_m *Service) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
ret := _m.Called(ctx, session, pm)
if len(ret) == 0 {
panic("no return value specified for ListChannels")
}
var r0 channels.Page
var r0 channels.ChannelsPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, channels.PageMetadata) (channels.Page, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, channels.Page) (channels.ChannelsPage, error)); ok {
return rf(ctx, session, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, channels.PageMetadata) channels.Page); ok {
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, channels.Page) channels.ChannelsPage); ok {
r0 = rf(ctx, session, pm)
} else {
r0 = ret.Get(0).(channels.Page)
r0 = ret.Get(0).(channels.ChannelsPage)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, channels.PageMetadata) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, channels.Page) error); ok {
r1 = rf(ctx, session, pm)
} else {
r1 = ret.Error(1)
@@ -275,25 +275,25 @@ func (_m *Service) ListEntityMembers(ctx context.Context, session authn.Session,
}
// ListUserChannels provides a mock function with given fields: ctx, session, userID, pm
func (_m *Service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
func (_m *Service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
ret := _m.Called(ctx, session, userID, pm)
if len(ret) == 0 {
panic("no return value specified for ListUserChannels")
}
var r0 channels.Page
var r0 channels.ChannelsPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.PageMetadata) (channels.Page, error)); ok {
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.Page) (channels.ChannelsPage, error)); ok {
return rf(ctx, session, userID, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.PageMetadata) channels.Page); ok {
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.Page) channels.ChannelsPage); ok {
r0 = rf(ctx, session, userID, pm)
} else {
r0 = ret.Get(0).(channels.Page)
r0 = ret.Get(0).(channels.ChannelsPage)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, channels.PageMetadata) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, channels.Page) error); ok {
r1 = rf(ctx, session, userID, pm)
} else {
r1 = ret.Error(1)
+46 -47
View File
@@ -14,7 +14,6 @@ import (
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/channels"
clients "github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
@@ -102,7 +101,7 @@ func (cr *channelRepository) Update(ctx context.Context, channel channels.Channe
WHERE id = :id AND status = :status
RETURNING id, name, tags, metadata, COALESCE(domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, status, created_at, updated_at, updated_by`,
upq)
channel.Status = clients.EnabledStatus
channel.Status = channels.EnabledStatus
return cr.update(ctx, channel, q)
}
@@ -110,7 +109,7 @@ func (cr *channelRepository) UpdateTags(ctx context.Context, channel channels.Ch
q := `UPDATE channels SET tags = :tags, updated_at = :updated_at, updated_by = :updated_by
WHERE id = :id AND status = :status
RETURNING id, name, tags, metadata, COALESCE(domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, status, created_at, updated_at, updated_by`
channel.Status = clients.EnabledStatus
channel.Status = channels.EnabledStatus
return cr.update(ctx, channel, q)
}
@@ -146,10 +145,10 @@ func (cr *channelRepository) RetrieveByID(ctx context.Context, id string) (chann
return channels.Channel{}, repoerr.ErrNotFound
}
func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMetadata) (channels.Page, error) {
func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.Page) (channels.ChannelsPage, error) {
pageQuery, err := PageQuery(pm)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
connJoinQuery := `
@@ -204,11 +203,11 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe
dbPage, err := toDBChannelsPage(pm)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
rows, err := cr.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
@@ -216,12 +215,12 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe
for rows.Next() {
dbch := dbChannel{}
if err := rows.StructScan(&dbch); err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
ch, err := toChannel(dbch)
if err != nil {
return channels.Page{}, err
return channels.ChannelsPage{}, err
}
items = append(items, ch)
@@ -234,12 +233,12 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe
total, err := postgres.Total(ctx, cr.db, cq, dbPage)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
page := channels.Page{
page := channels.ChannelsPage{
Channels: items,
PageMetadata: channels.PageMetadata{
Page: channels.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
@@ -248,14 +247,14 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe
return page, nil
}
func (repo *channelRepository) RetrieveUserChannels(ctx context.Context, domainID, userID string, pm channels.PageMetadata) (channels.Page, error) {
return repo.retrieveClients(ctx, domainID, userID, pm)
func (repo *channelRepository) RetrieveUserChannels(ctx context.Context, domainID, userID string, pm channels.Page) (channels.ChannelsPage, error) {
return repo.retrieveChannels(ctx, domainID, userID, pm)
}
func (repo *channelRepository) retrieveClients(ctx context.Context, domainID, userID string, pm channels.PageMetadata) (channels.Page, error) {
func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, userID string, pm channels.Page) (channels.ChannelsPage, error) {
pageQuery, err := PageQuery(pm)
if err != nil {
return channels.Page{}, err
return channels.ChannelsPage{}, err
}
bq := repo.userChannelsBaseQuery(domainID, userID)
@@ -316,12 +315,12 @@ func (repo *channelRepository) retrieveClients(ctx context.Context, domainID, us
dbPage, err := toDBChannelsPage(pm)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer rows.Close()
@@ -329,12 +328,12 @@ func (repo *channelRepository) retrieveClients(ctx context.Context, domainID, us
for rows.Next() {
dbc := dbChannel{}
if err := rows.StructScan(&dbc); err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
c, err := toChannel(dbc)
if err != nil {
return channels.Page{}, err
return channels.ChannelsPage{}, err
}
items = append(items, c)
@@ -371,12 +370,12 @@ func (repo *channelRepository) retrieveClients(ctx context.Context, domainID, us
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
page := channels.Page{
page := channels.ChannelsPage{
Channels: items,
PageMetadata: channels.PageMetadata{
Page: channels.Page{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
@@ -880,7 +879,7 @@ type dbChannel struct {
CreatedAt time.Time `db:"created_at,omitempty"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
UpdatedBy *string `db:"updated_by,omitempty"`
Status clients.Status `db:"status,omitempty"`
Status channels.Status `db:"status,omitempty"`
ParentGroupPath string `db:"parent_group_path,omitempty"`
RoleID string `db:"role_id,omitempty"`
RoleName string `db:"role_name,omitempty"`
@@ -952,7 +951,7 @@ func toString(s sql.NullString) string {
}
func toChannel(ch dbChannel) (channels.Channel, error) {
var metadata clients.Metadata
var metadata channels.Metadata
if ch.Metadata != nil {
if err := json.Unmarshal([]byte(ch.Metadata), &metadata); err != nil {
return channels.Channel{}, errors.Wrap(errors.ErrMalformedEntity, err)
@@ -1011,7 +1010,7 @@ func toChannel(ch dbChannel) (channels.Channel, error) {
return newCh, nil
}
func PageQuery(pm channels.PageMetadata) (string, error) {
func PageQuery(pm channels.Page) (string, error) {
mq, _, err := postgres.CreateMetadataQuery("", pm.Metadata)
if err != nil {
return "", errors.Wrap(errors.ErrMalformedEntity, err)
@@ -1022,8 +1021,8 @@ func PageQuery(pm channels.PageMetadata) (string, error) {
query = append(query, "c.name ILIKE '%' || :name || '%'")
}
if pm.Id != "" {
query = append(query, "c.id ILIKE '%' || :id || '%'")
if pm.ID != "" {
query = append(query, "c.id = :id")
}
if pm.Tag != "" {
query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')")
@@ -1036,7 +1035,7 @@ func PageQuery(pm channels.PageMetadata) (string, error) {
if len(pm.IDs) != 0 {
query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','")))
}
if pm.Status != clients.AllStatus {
if pm.Status != channels.AllStatus {
query = append(query, "c.status = :status")
}
if pm.Domain != "" {
@@ -1074,7 +1073,7 @@ func PageQuery(pm channels.PageMetadata) (string, error) {
return emq, nil
}
func applyOrdering(emq string, pm channels.PageMetadata) string {
func applyOrdering(emq string, pm channels.Page) string {
switch pm.Order {
case "name", "created_at", "updated_at":
emq = fmt.Sprintf("%s ORDER BY %s", emq, pm.Order)
@@ -1090,7 +1089,7 @@ func applyLimitOffset(query string) string {
LIMIT :limit OFFSET :offset`, query)
}
func toDBChannelsPage(pm channels.PageMetadata) (dbChannelsPage, error) {
func toDBChannelsPage(pm channels.Page) (dbChannelsPage, error) {
_, data, err := postgres.CreateMetadataQuery("", pm.Metadata)
if err != nil {
return dbChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
@@ -1099,7 +1098,7 @@ func toDBChannelsPage(pm channels.PageMetadata) (dbChannelsPage, error) {
Limit: pm.Limit,
Offset: pm.Offset,
Name: pm.Name,
Id: pm.Id,
Id: pm.ID,
Domain: pm.Domain,
Metadata: data,
Tag: pm.Tag,
@@ -1115,21 +1114,21 @@ func toDBChannelsPage(pm channels.PageMetadata) (dbChannelsPage, error) {
}
type dbChannelsPage struct {
Limit uint64 `db:"limit"`
Offset uint64 `db:"offset"`
Name string `db:"name"`
Id string `db:"id"`
Domain string `db:"domain_id"`
Metadata []byte `db:"metadata"`
Tag string `db:"tag"`
Status clients.Status `db:"status"`
GroupID string `db:"group_id"`
ClientID string `db:"client_id"`
ConnType string `db:"type"`
RoleName string `db:"role_name"`
RoleID string `db:"role_id"`
Actions pq.StringArray `db:"actions"`
AccessType string `db:"access_type"`
Limit uint64 `db:"limit"`
Offset uint64 `db:"offset"`
Name string `db:"name"`
Id string `db:"id"`
Domain string `db:"domain_id"`
Metadata []byte `db:"metadata"`
Tag string `db:"tag"`
Status channels.Status `db:"status"`
GroupID string `db:"group_id"`
ClientID string `db:"client_id"`
ConnType string `db:"type"`
RoleName string `db:"role_name"`
RoleID string `db:"role_id"`
Actions pq.StringArray `db:"actions"`
AccessType string `db:"access_type"`
}
type dbConnection struct {
+101 -64
View File
@@ -13,7 +13,6 @@ import (
"github.com/0x6flab/namegenerator"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/channels/postgres"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
@@ -33,7 +32,7 @@ var (
Tags: []string{"tag1", "tag2"},
Metadata: map[string]interface{}{"key": "value"},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
ConnectionTypes: []connections.ConnType{},
}
validConnection = channels.Connection{
@@ -79,7 +78,7 @@ func TestSave(t *testing.T) {
Name: namegen.Generate(),
Metadata: map[string]interface{}{"key": "value"},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
},
resp: []channels.Channel{},
err: repoerr.ErrMalformedEntity,
@@ -92,7 +91,7 @@ func TestSave(t *testing.T) {
Name: namegen.Generate(),
Metadata: map[string]interface{}{"key": "value"},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
},
resp: []channels.Channel{},
err: repoerr.ErrMalformedEntity,
@@ -105,7 +104,7 @@ func TestSave(t *testing.T) {
Name: strings.Repeat("a", 1025),
Metadata: map[string]interface{}{"key": "value"},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
},
resp: []channels.Channel{},
err: repoerr.ErrMalformedEntity,
@@ -120,7 +119,7 @@ func TestSave(t *testing.T) {
"key": make(chan int),
},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
},
resp: []channels.Channel{},
err: repoerr.ErrMalformedEntity,
@@ -306,7 +305,7 @@ func TestChangeStatus(t *testing.T) {
disabledChannel := validChannel
disabledChannel.ID = testsutil.GenerateUUID(t)
disabledChannel.Name = namegen.Generate()
disabledChannel.Status = clients.DisabledStatus
disabledChannel.Status = channels.DisabledStatus
_, err := repo.Save(context.Background(), validChannel, disabledChannel)
require.Nil(t, err, fmt.Sprintf("save channel unexpected error: %s", err))
@@ -320,7 +319,7 @@ func TestChangeStatus(t *testing.T) {
desc: "disable channel successfully",
channel: channels.Channel{
ID: validChannel.ID,
Status: clients.DisabledStatus,
Status: channels.DisabledStatus,
UpdatedAt: validTimestamp,
UpdatedBy: testsutil.GenerateUUID(t),
},
@@ -330,7 +329,7 @@ func TestChangeStatus(t *testing.T) {
desc: "enable channel successfully",
channel: channels.Channel{
ID: disabledChannel.ID,
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
UpdatedAt: validTimestamp,
UpdatedBy: testsutil.GenerateUUID(t),
},
@@ -340,7 +339,7 @@ func TestChangeStatus(t *testing.T) {
desc: "change status channel with invalid ID",
channel: channels.Channel{
ID: testsutil.GenerateUUID(t),
Status: clients.DisabledStatus,
Status: channels.DisabledStatus,
UpdatedAt: validTimestamp,
UpdatedBy: testsutil.GenerateUUID(t),
},
@@ -349,7 +348,7 @@ func TestChangeStatus(t *testing.T) {
{
desc: "change status channel with empty ID",
channel: channels.Channel{
Status: clients.DisabledStatus,
Status: channels.DisabledStatus,
UpdatedAt: validTimestamp,
UpdatedBy: testsutil.GenerateUUID(t),
},
@@ -438,7 +437,7 @@ func TestRetrieveAll(t *testing.T) {
Name: name,
Metadata: map[string]interface{}{"name": name},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
ConnectionTypes: []connections.ConnType{},
}
_, err := repo.Save(context.Background(), channel)
@@ -451,20 +450,20 @@ func TestRetrieveAll(t *testing.T) {
cases := []struct {
desc string
page channels.Page
response channels.Page
page channels.ChannelsPage
response channels.ChannelsPage
err error
}{
{
desc: "retrieve channels successfully",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 0,
Limit: 10,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(num),
Offset: 0,
Limit: 10,
@@ -475,14 +474,14 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with offset",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 10,
Limit: 10,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(num),
Offset: 10,
Limit: 10,
@@ -493,14 +492,14 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with limit",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 0,
Limit: 50,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(num),
Offset: 0,
Limit: 50,
@@ -511,14 +510,14 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with offset and limit",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 50,
Limit: 50,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(num),
Offset: 50,
Limit: 50,
@@ -529,14 +528,14 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with offset out of range",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 1000,
Limit: 50,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(num),
Offset: 1000,
Limit: 50,
@@ -547,14 +546,14 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with offset and limit out of range",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 170,
Limit: 50,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(num),
Offset: 170,
Limit: 50,
@@ -565,14 +564,14 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with limit out of range",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 0,
Limit: 1000,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(num),
Offset: 0,
Limit: 1000,
@@ -583,9 +582,9 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with empty page",
page: channels.Page{},
response: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{},
response: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(num),
Offset: 0,
Limit: 0,
@@ -596,15 +595,15 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with name",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 0,
Limit: 10,
Name: items[0].Name,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
Offset: 0,
Limit: 10,
@@ -615,15 +614,15 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with domain",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 0,
Limit: 10,
Domain: items[0].Domain,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
Offset: 0,
Limit: 10,
@@ -634,15 +633,15 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with metadata",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 0,
Limit: 10,
Metadata: items[0].Metadata,
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
Offset: 0,
Limit: 10,
@@ -653,8 +652,8 @@ func TestRetrieveAll(t *testing.T) {
},
{
desc: "retrieve channels with invalid metadata",
page: channels.Page{
PageMetadata: channels.PageMetadata{
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 0,
Limit: 10,
Metadata: map[string]interface{}{
@@ -662,8 +661,8 @@ func TestRetrieveAll(t *testing.T) {
},
},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: 0,
Offset: 0,
Limit: 10,
@@ -672,11 +671,49 @@ func TestRetrieveAll(t *testing.T) {
},
err: errors.ErrMalformedEntity,
},
{
desc: "retrieve channels with id",
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 0,
Limit: 10,
ID: items[0].ID,
},
},
response: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
Offset: 0,
Limit: 10,
},
Channels: []channels.Channel{items[0]},
},
err: nil,
},
{
desc: "retrieve channels with wrong id",
page: channels.ChannelsPage{
Page: channels.Page{
Offset: 0,
Limit: 10,
ID: "wrong",
},
},
response: channels.ChannelsPage{
Page: channels.Page{
Total: 0,
Offset: 0,
Limit: 10,
},
Channels: []channels.Channel(nil),
},
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
switch channels, err := repo.RetrieveAll(context.Background(), tc.page.PageMetadata); {
switch channels, err := repo.RetrieveAll(context.Background(), tc.page.Page); {
case err == nil:
assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response.Total, channels.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Total, channels.Total))
@@ -1338,7 +1375,7 @@ func TestRetrieveParentGroupChannels(t *testing.T) {
Name: name,
Metadata: map[string]interface{}{"name": name},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
ConnectionTypes: []connections.ConnType{},
}
items = append(items, channel)
@@ -1406,7 +1443,7 @@ func TestUnsetParentGroupFromChannels(t *testing.T) {
Name: name,
Metadata: map[string]interface{}{"name": name},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
}
items = append(items, channel)
}
+15 -15
View File
@@ -13,7 +13,6 @@ import (
grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1"
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
apiutil "github.com/absmach/supermq/api/http/util"
smqclients "github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/connections"
"github.com/absmach/supermq/pkg/errors"
@@ -67,7 +66,7 @@ func (svc service) CreateChannels(ctx context.Context, session authn.Session, ch
c.ID = clientID
}
if c.Status != smqclients.DisabledStatus && c.Status != smqclients.EnabledStatus {
if c.Status != DisabledStatus && c.Status != EnabledStatus {
return []Channel{}, []roles.RoleProvision{}, svcerr.ErrInvalidStatus
}
c.Domain = session.DomainID
@@ -148,12 +147,12 @@ func (svc service) UpdateChannelTags(ctx context.Context, session authn.Session,
func (svc service) EnableChannel(ctx context.Context, session authn.Session, id string) (Channel, error) {
channel := Channel{
ID: id,
Status: smqclients.EnabledStatus,
Status: EnabledStatus,
UpdatedAt: time.Now(),
}
ch, err := svc.changeChannelStatus(ctx, session.UserID, channel)
if err != nil {
return Channel{}, errors.Wrap(smqclients.ErrEnableClient, err)
return Channel{}, errors.Wrap(ErrEnableChannel, err)
}
return ch, nil
@@ -162,12 +161,12 @@ func (svc service) EnableChannel(ctx context.Context, session authn.Session, id
func (svc service) DisableChannel(ctx context.Context, session authn.Session, id string) (Channel, error) {
channel := Channel{
ID: id,
Status: smqclients.DisabledStatus,
Status: DisabledStatus,
UpdatedAt: time.Now(),
}
ch, err := svc.changeChannelStatus(ctx, session.UserID, channel)
if err != nil {
return Channel{}, errors.Wrap(smqclients.ErrDisableClient, err)
return Channel{}, errors.Wrap(ErrDisableChannel, err)
}
return ch, nil
@@ -178,30 +177,31 @@ func (svc service) ViewChannel(ctx context.Context, session authn.Session, id st
if err != nil {
return Channel{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return channel, nil
}
func (svc service) ListChannels(ctx context.Context, session authn.Session, pm PageMetadata) (Page, error) {
func (svc service) ListChannels(ctx context.Context, session authn.Session, pm Page) (ChannelsPage, error) {
switch session.SuperAdmin {
case true:
cp, err := svc.repo.RetrieveAll(ctx, pm)
if err != nil {
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
return ChannelsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return cp, nil
default:
cp, err := svc.repo.RetrieveUserChannels(ctx, session.DomainID, session.UserID, pm)
if err != nil {
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
return ChannelsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return cp, nil
}
}
func (svc service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm PageMetadata) (Page, error) {
func (svc service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm Page) (ChannelsPage, error) {
cp, err := svc.repo.RetrieveUserChannels(ctx, session.DomainID, userID, pm)
if err != nil {
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
return ChannelsPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return cp, nil
}
@@ -217,7 +217,7 @@ func (svc service) RemoveChannel(ctx context.Context, session authn.Session, id
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
}
ch, err := svc.repo.ChangeStatus(ctx, Channel{ID: id, Status: smqclients.DeletedStatus})
ch, err := svc.repo.ChangeStatus(ctx, Channel{ID: id, Status: DeletedStatus})
if err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
@@ -270,7 +270,7 @@ func (svc service) Connect(ctx context.Context, session authn.Session, chIDs, th
if err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
if c.Status != smqclients.EnabledStatus {
if c.Status != EnabledStatus {
return errors.Wrap(svcerr.ErrCreateEntity, fmt.Errorf("channel id %s is not in enabled state", chID))
}
if c.Domain != session.DomainID {
@@ -283,7 +283,7 @@ func (svc service) Connect(ctx context.Context, session authn.Session, chIDs, th
if err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
if resp.GetEntity().GetStatus() != uint32(smqclients.EnabledStatus) {
if resp.GetEntity().GetStatus() != uint32(EnabledStatus) {
return errors.Wrap(svcerr.ErrCreateEntity, fmt.Errorf("client id %s is not in enabled state", thID))
}
if resp.GetEntity().GetDomainId() != session.DomainID {
@@ -399,7 +399,7 @@ func (svc service) SetParentGroup(ctx context.Context, session authn.Session, pa
if resp.GetEntity().GetDomainId() != session.DomainID {
return errors.Wrap(svcerr.ErrUpdateEntity, fmt.Errorf("parent group id %s has invalid domain id", parentGroupID))
}
if resp.GetEntity().GetStatus() != uint32(smqclients.EnabledStatus) {
if resp.GetEntity().GetStatus() != uint32(EnabledStatus) {
return errors.Wrap(svcerr.ErrUpdateEntity, fmt.Errorf("parent group id %s is not in enabled state", parentGroupID))
}
+58 -60
View File
@@ -15,8 +15,6 @@ import (
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/channels/mocks"
"github.com/absmach/supermq/clients"
smqclients "github.com/absmach/supermq/clients"
clmocks "github.com/absmach/supermq/clients/mocks"
gpmocks "github.com/absmach/supermq/groups/mocks"
"github.com/absmach/supermq/internal/testsutil"
@@ -45,7 +43,7 @@ var (
},
Tags: []string{"tag1", "tag2"},
Domain: testsutil.GenerateUUID(&testing.T{}),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
}
parentGroupID = testsutil.GenerateUUID(&testing.T{})
validID = testsutil.GenerateUUID(&testing.T{})
@@ -67,7 +65,7 @@ func newService(t *testing.T) channels.Service {
groupsSvc = new(gpmocks.GroupsServiceClient)
availableActions := []roles.Action{}
builtInRoles := map[roles.BuiltInRoleName][]roles.Action{
clients.BuiltInRoleAdmin: availableActions,
channels.BuiltInRoleAdmin: availableActions,
}
svc, err := channels.New(repo, policies, idProvider, clientsSvc, groupsSvc, idProvider, availableActions, builtInRoles)
assert.Nil(t, err, fmt.Sprintf(" Unexpected error while creating service %v", err))
@@ -102,7 +100,7 @@ func TestCreateChannel(t *testing.T) {
desc: "create channel with invalid status",
channel: channels.Channel{
Name: namegen.Generate(),
Status: clients.Status(100),
Status: channels.Status(100),
},
err: svcerr.ErrInvalidStatus,
},
@@ -110,7 +108,7 @@ func TestCreateChannel(t *testing.T) {
desc: "create channel successfully with parent",
channel: channels.Channel{
Name: namegen.Generate(),
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
ParentGroup: testsutil.GenerateUUID(t),
},
saveResp: []channels.Channel{
@@ -353,7 +351,7 @@ func TestEnableChannel(t *testing.T) {
desc: "enable channel successfully",
id: testsutil.GenerateUUID(t),
retrieveResp: channels.Channel{
Status: clients.DisabledStatus,
Status: channels.DisabledStatus,
},
changeResp: validChannel,
},
@@ -361,7 +359,7 @@ func TestEnableChannel(t *testing.T) {
desc: "enable channel with enabled channel",
id: testsutil.GenerateUUID(t),
retrieveResp: channels.Channel{
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
},
err: errors.ErrStatusAlreadyAssigned,
},
@@ -376,7 +374,7 @@ func TestEnableChannel(t *testing.T) {
desc: "enable channel with change status error",
id: testsutil.GenerateUUID(t),
retrieveResp: channels.Channel{
Status: clients.DisabledStatus,
Status: channels.DisabledStatus,
},
changeErr: repoerr.ErrNotFound,
err: repoerr.ErrNotFound,
@@ -416,7 +414,7 @@ func TestDisableChannel(t *testing.T) {
desc: "disable channel successfully",
id: testsutil.GenerateUUID(t),
retrieveResp: channels.Channel{
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
},
changeResp: validChannel,
},
@@ -424,7 +422,7 @@ func TestDisableChannel(t *testing.T) {
desc: "disable channel with disabled channel",
id: testsutil.GenerateUUID(t),
retrieveResp: channels.Channel{
Status: clients.DisabledStatus,
Status: channels.DisabledStatus,
},
err: errors.ErrStatusAlreadyAssigned,
},
@@ -438,7 +436,7 @@ func TestDisableChannel(t *testing.T) {
{
desc: "disable channel with change status error",
id: testsutil.GenerateUUID(t),
retrieveResp: channels.Channel{Status: clients.EnabledStatus},
retrieveResp: channels.Channel{Status: channels.EnabledStatus},
changeErr: repoerr.ErrNotFound,
err: repoerr.ErrNotFound,
},
@@ -472,9 +470,9 @@ func TestListChannels(t *testing.T) {
desc string
userKind string
session smqauthn.Session
page channels.PageMetadata
retrieveAllResponse channels.Page
response channels.Page
page channels.Page
retrieveAllResponse channels.ChannelsPage
response channels.ChannelsPage
id string
size uint64
listObjectsErr error
@@ -487,20 +485,20 @@ func TestListChannels(t *testing.T) {
userKind: "non-admin",
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: channels.PageMetadata{
page: channels.Page{
Offset: 0,
Limit: 100,
},
retrieveAllResponse: channels.Page{
PageMetadata: channels.PageMetadata{
retrieveAllResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 2,
Offset: 0,
Limit: 100,
},
Channels: []channels.Channel{validChannel, validChannel},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: 2,
Offset: 0,
Limit: 100,
@@ -514,12 +512,12 @@ func TestListChannels(t *testing.T) {
userKind: "non-admin",
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: channels.PageMetadata{
page: channels.Page{
Offset: 0,
Limit: 100,
},
retrieveAllResponse: channels.Page{},
response: channels.Page{},
retrieveAllResponse: channels.ChannelsPage{},
response: channels.ChannelsPage{},
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
@@ -528,23 +526,23 @@ func TestListChannels(t *testing.T) {
userKind: "non-admin",
session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false},
id: nonAdminID,
page: channels.PageMetadata{
page: channels.Page{
Offset: 0,
Limit: 100,
},
response: channels.Page{},
response: channels.ChannelsPage{},
err: nil,
},
{
desc: "list all channels as non admin with failed to list objects",
userKind: "non-admin",
id: nonAdminID,
page: channels.PageMetadata{
page: channels.Page{
Offset: 0,
Limit: 100,
},
retrieveAllErr: repoerr.ErrNotFound,
response: channels.Page{},
response: channels.ChannelsPage{},
listObjectsErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
@@ -564,9 +562,9 @@ func TestListChannels(t *testing.T) {
desc string
userKind string
session smqauthn.Session
page channels.PageMetadata
retrieveAllResponse channels.Page
response channels.Page
page channels.Page
retrieveAllResponse channels.ChannelsPage
response channels.ChannelsPage
id string
size uint64
listObjectsErr error
@@ -579,21 +577,21 @@ func TestListChannels(t *testing.T) {
userKind: "admin",
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: channels.PageMetadata{
page: channels.Page{
Offset: 0,
Limit: 100,
Domain: domainID,
},
retrieveAllResponse: channels.Page{
PageMetadata: channels.PageMetadata{
retrieveAllResponse: channels.ChannelsPage{
Page: channels.Page{
Total: 2,
Offset: 0,
Limit: 100,
},
Channels: []channels.Channel{validChannel, validChannel},
},
response: channels.Page{
PageMetadata: channels.PageMetadata{
response: channels.ChannelsPage{
Page: channels.Page{
Total: 2,
Offset: 0,
Limit: 100,
@@ -607,12 +605,12 @@ func TestListChannels(t *testing.T) {
userKind: "admin",
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: channels.PageMetadata{
page: channels.Page{
Offset: 0,
Limit: 100,
Domain: domainID,
},
retrieveAllResponse: channels.Page{},
retrieveAllResponse: channels.ChannelsPage{},
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
@@ -621,12 +619,12 @@ func TestListChannels(t *testing.T) {
userKind: "admin",
id: adminID,
session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true},
page: channels.PageMetadata{
page: channels.Page{
Offset: 0,
Limit: 100,
Domain: domainID,
},
retrieveAllResponse: channels.Page{},
retrieveAllResponse: channels.ChannelsPage{},
retrieveAllErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
@@ -645,7 +643,7 @@ func TestRemoveChannel(t *testing.T) {
svc := newService(t)
deletedChannel := validChannel
deletedChannel.Status = clients.DeletedStatus
deletedChannel.Status = channels.DeletedStatus
channelWithParent := deletedChannel
channelWithParent.ParentGroup = testsutil.GenerateUUID(t)
@@ -732,7 +730,7 @@ func TestRemoveChannel(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
repoCall := repo.On("DoesChannelHaveConnections", context.Background(), validChannel.ID).Return(tc.connectionsRes, tc.connectionsErr)
clientsCall := clientsSvc.On("RemoveChannelConnections", context.Background(), &grpcClientsV1.RemoveChannelConnectionsReq{ChannelId: tc.id}).Return(&grpcClientsV1.RemoveChannelConnectionsRes{}, tc.removeConnectionsErr)
repoCall1 := repo.On("ChangeStatus", context.Background(), channels.Channel{ID: tc.id, Status: smqclients.DeletedStatus}).Return(tc.changeStatusRes, tc.changeStatusErr)
repoCall1 := repo.On("ChangeStatus", context.Background(), channels.Channel{ID: tc.id, Status: channels.DeletedStatus}).Return(tc.changeStatusRes, tc.changeStatusErr)
repoCall2 := repo.On("RetrieveEntitiesRolesActionsMembers", context.Background(), []string{tc.id}).Return([]roles.EntityActionRole{}, []roles.EntityMemberRole{}, nil)
policyCall := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr)
policyCall1 := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePolicyFilterErr)
@@ -757,7 +755,7 @@ func TestConnect(t *testing.T) {
validDomainChannel.Domain = validID
disabledChannel := validChannel
disabledChannel.Status = clients.DisabledStatus
disabledChannel.Status = channels.DisabledStatus
cases := []struct {
desc string
@@ -785,7 +783,7 @@ func TestConnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
checkConnErr: repoerr.ErrNotFound,
@@ -845,7 +843,7 @@ func TestConnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: validID,
Status: uint32(clients.DisabledStatus),
Status: uint32(channels.DisabledStatus),
},
},
err: svcerr.ErrCreateEntity,
@@ -859,7 +857,7 @@ func TestConnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: testsutil.GenerateUUID(t),
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
err: svcerr.ErrCreateEntity,
@@ -874,7 +872,7 @@ func TestConnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
repoConn: channels.Connection{
@@ -896,7 +894,7 @@ func TestConnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
repoConn: channels.Connection{
@@ -918,7 +916,7 @@ func TestConnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
repoConn: channels.Connection{
@@ -949,7 +947,7 @@ func TestConnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
repoConn: channels.Connection{
@@ -1021,7 +1019,7 @@ func TestDisconnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
repoConn: channels.Connection{
@@ -1073,7 +1071,7 @@ func TestDisconnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: testsutil.GenerateUUID(t),
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
err: svcerr.ErrRemoveEntity,
@@ -1088,7 +1086,7 @@ func TestDisconnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
repoConn: channels.Connection{
@@ -1118,7 +1116,7 @@ func TestDisconnect(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
repoConn: channels.Connection{
@@ -1188,7 +1186,7 @@ func TestSetParentGroup(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: parentGroupID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
err: nil,
@@ -1219,7 +1217,7 @@ func TestSetParentGroup(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: parentGroupID,
DomainId: testsutil.GenerateUUID(t),
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
err: svcerr.ErrUpdateEntity,
@@ -1233,7 +1231,7 @@ func TestSetParentGroup(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: parentGroupID,
DomainId: validID,
Status: uint32(clients.DisabledStatus),
Status: uint32(channels.DisabledStatus),
},
},
err: svcerr.ErrUpdateEntity,
@@ -1247,7 +1245,7 @@ func TestSetParentGroup(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: parentGroupID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
err: svcerr.ErrConflict,
@@ -1261,7 +1259,7 @@ func TestSetParentGroup(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: parentGroupID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
addPoliciesErr: svcerr.ErrAuthorization,
@@ -1276,7 +1274,7 @@ func TestSetParentGroup(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: parentGroupID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
setParentGroupErr: repoerr.ErrNotFound,
@@ -1291,7 +1289,7 @@ func TestSetParentGroup(t *testing.T) {
Entity: &grpcCommonV1.EntityBasic{
Id: parentGroupID,
DomainId: validID,
Status: uint32(clients.EnabledStatus),
Status: uint32(channels.EnabledStatus),
},
},
setParentGroupErr: repoerr.ErrNotFound,
+94
View File
@@ -0,0 +1,94 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package channels
import (
"encoding/json"
"strings"
svcerr "github.com/absmach/supermq/pkg/errors/service"
)
// Status represents Channel status.
type Status uint8
// Possible Channel status values.
const (
// EnabledStatus represents enabled Channel.
EnabledStatus Status = iota
// DisabledStatus represents disabled Channel.
DisabledStatus
// DeletedStatus represents deleted Channel.
DeletedStatus
// AllStatus is used for querying purposes to list channels irrespective
// of their status - both active and inactive. It is never stored in the
// database as the actual Channel status and should always be the largest
// value in this enumeration.
AllStatus
)
// String representation of the possible status values.
const (
Disabled = "disabled"
Enabled = "enabled"
Deleted = "deleted"
All = "all"
Unknown = "unknown"
)
// String converts Channel status to string literal.
func (s Status) String() string {
switch s {
case DisabledStatus:
return Disabled
case EnabledStatus:
return Enabled
case DeletedStatus:
return Deleted
case AllStatus:
return All
default:
return Unknown
}
}
// ToStatus converts string value to a valid Channel status.
func ToStatus(status string) (Status, error) {
switch status {
case Disabled:
return DisabledStatus, nil
case Enabled:
return EnabledStatus, nil
case Deleted:
return DeletedStatus, nil
case All:
return AllStatus, nil
}
return Status(0), svcerr.ErrInvalidStatus
}
// Custom Marshaller for Status.
func (s Status) MarshalJSON() ([]byte, error) {
return json.Marshal(s.String())
}
func (channel Channel) MarshalJSON() ([]byte, error) {
type Alias Channel
return json.Marshal(&struct {
Alias
Status string `json:"status,omitempty"`
}{
Alias: (Alias)(channel),
Status: channel.Status.String(),
})
}
// Custom Unmarshaler for Status.
func (s *Status) UnmarshalJSON(data []byte) error {
str := strings.Trim(string(data), "\"")
val, err := ToStatus(str)
*s = val
return err
}
+246
View File
@@ -0,0 +1,246 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package channels_test
import (
"testing"
"github.com/absmach/supermq/channels"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/stretchr/testify/assert"
)
func TestStatusString(t *testing.T) {
cases := []struct {
desc string
status channels.Status
expected string
}{
{
desc: "Enabled",
status: channels.EnabledStatus,
expected: "enabled",
},
{
desc: "Disabled",
status: channels.DisabledStatus,
expected: "disabled",
},
{
desc: "Deleted",
status: channels.DeletedStatus,
expected: "deleted",
},
{
desc: "All",
status: channels.AllStatus,
expected: "all",
},
{
desc: "Unknown",
status: channels.Status(100),
expected: "unknown",
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got := tc.status.String()
assert.Equal(t, tc.expected, got, "String() = %v, expected %v", got, tc.expected)
})
}
}
func TestToStatus(t *testing.T) {
cases := []struct {
desc string
status string
expetcted channels.Status
err error
}{
{
desc: "Enabled",
status: "enabled",
expetcted: channels.EnabledStatus,
err: nil,
},
{
desc: "Disabled",
status: "disabled",
expetcted: channels.DisabledStatus,
err: nil,
},
{
desc: "Deleted",
status: "deleted",
expetcted: channels.DeletedStatus,
err: nil,
},
{
desc: "All",
status: "all",
expetcted: channels.AllStatus,
err: nil,
},
{
desc: "Unknown",
status: "unknown",
expetcted: channels.Status(0),
err: svcerr.ErrInvalidStatus,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got, err := channels.ToStatus(tc.status)
assert.Equal(t, tc.err, err, "ToStatus() error = %v, expected %v", err, tc.err)
assert.Equal(t, tc.expetcted, got, "ToStatus() = %v, expected %v", got, tc.expetcted)
})
}
}
func TestStatusMarshalJSON(t *testing.T) {
cases := []struct {
desc string
expected []byte
status channels.Status
err error
}{
{
desc: "Enabled",
expected: []byte(`"enabled"`),
status: channels.EnabledStatus,
err: nil,
},
{
desc: "Disabled",
expected: []byte(`"disabled"`),
status: channels.DisabledStatus,
err: nil,
},
{
desc: "Deleted",
expected: []byte(`"deleted"`),
status: channels.DeletedStatus,
err: nil,
},
{
desc: "All",
expected: []byte(`"all"`),
status: channels.AllStatus,
err: nil,
},
{
desc: "Unknown",
expected: []byte(`"unknown"`),
status: channels.Status(100),
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got, err := tc.status.MarshalJSON()
assert.Equal(t, tc.err, err, "MarshalJSON() error = %v, expected %v", err, tc.err)
assert.Equal(t, tc.expected, got, "MarshalJSON() = %v, expected %v", got, tc.expected)
})
}
}
func TestStatusUnmarshalJSON(t *testing.T) {
cases := []struct {
desc string
expected channels.Status
status []byte
err error
}{
{
desc: "Enabled",
expected: channels.EnabledStatus,
status: []byte(`"enabled"`),
err: nil,
},
{
desc: "Disabled",
expected: channels.DisabledStatus,
status: []byte(`"disabled"`),
err: nil,
},
{
desc: "Deleted",
expected: channels.DeletedStatus,
status: []byte(`"deleted"`),
err: nil,
},
{
desc: "All",
expected: channels.AllStatus,
status: []byte(`"all"`),
err: nil,
},
{
desc: "Unknown",
expected: channels.Status(0),
status: []byte(`"unknown"`),
err: svcerr.ErrInvalidStatus,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
var s channels.Status
err := s.UnmarshalJSON(tc.status)
assert.Equal(t, tc.err, err, "UnmarshalJSON() error = %v, expected %v", err, tc.err)
assert.Equal(t, tc.expected, s, "UnmarshalJSON() = %v, expected %v", s, tc.expected)
})
}
}
func TestChannelMarshalJSON(t *testing.T) {
cases := []struct {
desc string
expected []byte
user channels.Channel
err error
}{
{
desc: "Enabled",
expected: []byte(`{"id":"","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","status":"enabled"}`),
user: channels.Channel{Status: channels.EnabledStatus},
err: nil,
},
{
desc: "Disabled",
expected: []byte(`{"id":"","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","status":"disabled"}`),
user: channels.Channel{Status: channels.DisabledStatus},
err: nil,
},
{
desc: "Deleted",
expected: []byte(`{"id":"","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","status":"deleted"}`),
user: channels.Channel{Status: channels.DeletedStatus},
err: nil,
},
{
desc: "All",
expected: []byte(`{"id":"","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","status":"all"}`),
user: channels.Channel{Status: channels.AllStatus},
err: nil,
},
{
desc: "Unknown",
expected: []byte(`{"id":"","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","status":"unknown"}`),
user: channels.Channel{Status: channels.Status(100)},
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got, err := tc.user.MarshalJSON()
assert.Equal(t, tc.err, err, "MarshalJSON() error = %v, expected %v", err, tc.err)
assert.Equal(t, tc.expected, got, "MarshalJSON() = %v, expected %v", string(got), string(tc.expected))
})
}
}
+2 -2
View File
@@ -45,13 +45,13 @@ func (tm *tracingMiddleware) ViewChannel(ctx context.Context, session authn.Sess
}
// ListChannels traces the "ListChannels" operation of the wrapped policies.Service.
func (tm *tracingMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.PageMetadata) (channels.Page, error) {
func (tm *tracingMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_list_channels")
defer span.End()
return tm.svc.ListChannels(ctx, session, pm)
}
func (tm *tracingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
func (tm *tracingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_list_user_channels")
defer span.End()
return tm.svc.ListUserChannels(ctx, session, userID, pm)
+24 -16
View File
@@ -116,23 +116,31 @@ func decodeListClients(_ context.Context, r *http.Request) (interface{}, error)
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
if err != nil {
return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
req := listClientsReq{
name: name,
tag: tag,
status: status,
metadata: meta,
roleName: roleName,
roleID: roleID,
actions: actions,
accessType: accessType,
order: order,
dir: dir,
offset: offset,
limit: limit,
groupID: groupID,
channelID: channelID,
connType: connType,
userID: userID,
Page: clients.Page{
Name: name,
Tag: tag,
Status: status,
Metadata: meta,
RoleName: roleName,
RoleID: roleID,
Actions: actions,
AccessType: accessType,
Order: order,
Dir: dir,
Offset: offset,
Limit: limit,
Group: groupID,
Channel: channelID,
ConnectionType: connType,
ID: id,
},
userID: userID,
}
return req, nil
}
+2 -20
View File
@@ -103,31 +103,13 @@ func listClientsEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, svcerr.ErrAuthentication
}
pm := clients.Page{
Name: req.name,
Tag: req.tag,
Status: req.status,
Metadata: req.metadata,
RoleName: req.roleName,
RoleID: req.roleID,
Actions: req.actions,
AccessType: req.accessType,
Order: req.order,
Dir: req.dir,
Offset: req.offset,
Limit: req.limit,
Group: req.groupID,
Channel: req.channelID,
ConnectionType: req.connType,
}
var page clients.ClientsPage
var err error
switch req.userID != "" {
case true:
page, err = svc.ListUserClients(ctx, session, req.userID, pm)
page, err = svc.ListUserClients(ctx, session, req.userID, req.Page)
default:
page, err = svc.ListClients(ctx, session, pm)
page, err = svc.ListClients(ctx, session, req.Page)
}
if err != nil {
return clientsPageRes{}, err
+4 -18
View File
@@ -71,30 +71,16 @@ func (req viewClientPermsReq) validate() error {
}
type listClientsReq struct {
name string
tag string
status clients.Status
metadata clients.Metadata
roleName string
roleID string
actions []string
accessType string
order string
dir string
offset uint64
limit uint64
groupID string
channelID string
connType string
userID string
clients.Page
userID string
}
func (req listClientsReq) validate() error {
if req.limit > api.MaxLimitSize || req.limit < 1 {
if req.Limit > api.MaxLimitSize || req.Limit < 1 {
return apiutil.ErrLimitSize
}
if len(req.name) > api.MaxNameSize {
if len(req.Name) > api.MaxNameSize {
return apiutil.ErrNameSize
}
+7 -5
View File
@@ -192,29 +192,31 @@ func TestListClientsReqValidate(t *testing.T) {
{
desc: "valid request",
req: listClientsReq{
limit: 10,
Page: clients.Page{Limit: 10},
},
err: nil,
},
{
desc: "limit too big",
req: listClientsReq{
limit: api.MaxLimitSize + 1,
Page: clients.Page{Limit: api.MaxLimitSize + 1},
},
err: apiutil.ErrLimitSize,
},
{
desc: "limit too small",
req: listClientsReq{
limit: 0,
Page: clients.Page{Limit: 0},
},
err: apiutil.ErrLimitSize,
},
{
desc: "name too long",
req: listClientsReq{
limit: 10,
name: strings.Repeat("a", api.MaxNameSize+1),
Page: clients.Page{
Limit: 10,
Name: strings.Repeat("a", api.MaxNameSize+1),
},
},
err: apiutil.ErrNameSize,
},
+1 -1
View File
@@ -201,7 +201,7 @@ type Page struct {
Limit uint64 `json:"limit"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
Id string `json:"id,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
+3 -3
View File
@@ -930,7 +930,7 @@ func ToDBClientsPage(pm clients.Page) (dbClientsPage, error) {
Limit: pm.Limit,
Name: pm.Name,
Identity: pm.Identity,
Id: pm.Id,
Id: pm.ID,
Metadata: data,
Domain: pm.Domain,
Status: pm.Status,
@@ -977,8 +977,8 @@ func PageQuery(pm clients.Page) (string, error) {
if pm.Identity != "" {
query = append(query, "c.identity ILIKE '%' || :identity || '%'")
}
if pm.Id != "" {
query = append(query, "c.id ILIKE '%' || :id || '%'")
if pm.ID != "" {
query = append(query, "c.id = :id")
}
if pm.Tag != "" {
query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')")
+34
View File
@@ -1229,6 +1229,40 @@ func TestRetrieveAll(t *testing.T) {
Clients: []clients.Client{expectedClients[0]},
},
},
{
desc: "with id",
pm: clients.Page{
Offset: 0,
Limit: nClients,
ID: expectedClients[0].ID,
Status: clients.AllStatus,
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 1,
Offset: 0,
Limit: nClients,
},
Clients: []clients.Client{expectedClients[0]},
},
},
{
desc: "with wrong id",
pm: clients.Page{
Offset: 0,
Limit: nClients,
ID: testsutil.GenerateUUID(t),
Status: clients.AllStatus,
},
response: clients.ClientsPage{
Page: clients.Page{
Total: 0,
Offset: 0,
Limit: nClients,
},
Clients: []clients.Client(nil),
},
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
+32 -26
View File
@@ -66,7 +66,7 @@ func decodeListDomainRequest(ctx context.Context, r *http.Request) (interface{},
return nil, err
}
req := listDomainsReq{
page: page,
page,
}
return req, nil
@@ -93,47 +93,47 @@ func decodeFreezeDomainRequest(_ context.Context, r *http.Request) (interface{},
return req, nil
}
func decodePageRequest(_ context.Context, r *http.Request) (page, error) {
func decodePageRequest(_ context.Context, r *http.Request) (domains.Page, error) {
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefClientStatus)
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
st, err := domains.ToStatus(s)
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
or, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder)
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir)
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
m, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
n, err := apiutil.ReadStringQuery(r, api.NameKey, "")
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
t, err := apiutil.ReadStringQuery(r, api.TagKey, "")
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
allActions, err := apiutil.ReadStringQuery(r, api.ActionsKey, "")
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
var actions []string
@@ -144,26 +144,32 @@ func decodePageRequest(_ context.Context, r *http.Request) (page, error) {
}
roleID, err := apiutil.ReadStringQuery(r, api.RoleIDKey, "")
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
roleName, err := apiutil.ReadStringQuery(r, api.RoleNameKey, "")
if err != nil {
return page{}, errors.Wrap(apiutil.ErrValidation, err)
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
return page{
offset: o,
order: or,
dir: dir,
limit: l,
name: n,
metadata: m,
tag: t,
roleID: roleID,
roleName: roleName,
actions: actions,
status: st,
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
if err != nil {
return domains.Page{}, errors.Wrap(apiutil.ErrValidation, err)
}
return domains.Page{
Offset: o,
Order: or,
Dir: dir,
Limit: l,
Name: n,
Metadata: m,
Tag: t,
RoleID: roleID,
RoleName: roleName,
Actions: actions,
Status: st,
ID: id,
}, nil
}
+1 -14
View File
@@ -109,20 +109,7 @@ func listDomainsEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, svcerr.ErrAuthorization
}
page := domains.Page{
Offset: req.offset,
Limit: req.limit,
Name: req.name,
Metadata: req.metadata,
Order: req.order,
Dir: req.dir,
Tag: req.tag,
RoleID: req.roleID,
RoleName: req.roleName,
Actions: req.actions,
Status: req.status,
}
dp, err := svc.ListDomains(ctx, session, page)
dp, err := svc.ListDomains(ctx, session, req.Page)
if err != nil {
return nil, err
}
+1 -15
View File
@@ -11,20 +11,6 @@ import (
const maxLimitSize = 100
type page struct {
offset uint64
limit uint64
order string
dir string
name string
metadata map[string]interface{}
tag string
roleID string
roleName string
actions []string
status domains.Status
}
type createDomainReq struct {
ID string `json:"id,omitempty"`
Name string `json:"name"`
@@ -76,7 +62,7 @@ func (req updateDomainReq) validate() error {
}
type listDomainsReq struct {
page
domains.Page
}
func (req listDomainsReq) validate() error {
+2 -2
View File
@@ -596,7 +596,7 @@ func buildPageQuery(pm domains.Page) (string, error) {
}
if pm.Name != "" {
query = append(query, "d.name = :name")
query = append(query, "d.name ILIKE '%' || :name || '%'")
}
if pm.UserID != "" {
@@ -614,7 +614,7 @@ func buildPageQuery(pm domains.Page) (string, error) {
}
if pm.Tag != "" {
query = append(query, ":tag = ANY(d.tags)")
query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')")
}
mq, _, err := postgres.CreateMetadataQuery("", pm.Metadata)
+63
View File
@@ -692,6 +692,22 @@ func TestListDomains(t *testing.T) {
},
err: nil,
},
{
desc: "list all domains with name",
pm: domains.Page{
Offset: 0,
Limit: 10,
Name: items[0].Name,
Status: domains.AllStatus,
},
response: domains.DomainsPage{
Total: 1,
Offset: 0,
Limit: 10,
Domains: []domains.Domain{items[0]},
},
err: nil,
},
{
desc: "list all domains with disabled status",
pm: domains.Page{
@@ -723,6 +739,22 @@ func TestListDomains(t *testing.T) {
},
err: nil,
},
{
desc: "list all domains with invalid tag",
pm: domains.Page{
Offset: 0,
Limit: 10,
Tag: "invalid",
Status: domains.AllStatus,
},
response: domains.DomainsPage{
Total: 0,
Offset: 0,
Limit: 10,
Domains: []domains.Domain(nil),
},
err: nil,
},
{
desc: "list all domains with metadata",
pm: domains.Page{
@@ -769,6 +801,37 @@ func TestListDomains(t *testing.T) {
},
err: nil,
},
{
desc: "list domains with id",
pm: domains.Page{
Offset: 0,
Limit: 10,
ID: items[0].ID,
Status: domains.AllStatus,
},
response: domains.DomainsPage{
Total: 1,
Offset: 0,
Limit: 10,
Domains: []domains.Domain{items[0]},
},
err: nil,
},
{
desc: "list domains with invalid id",
pm: domains.Page{
Offset: 0,
Limit: 10,
ID: invalid,
Status: domains.AllStatus,
},
response: domains.DomainsPage{
Total: 0,
Offset: 0,
Limit: 10,
},
err: nil,
},
}
for _, tc := range cases {
+1 -1
View File
@@ -900,7 +900,7 @@ func buildQuery(gm groups.PageMeta, ids ...string) string {
queries = append(queries, "g.name ILIKE '%' || :name || '%'")
}
if gm.ID != "" {
queries = append(queries, "g.id ILIKE '%' || :id || '%'")
queries = append(queries, "g.id = :id")
}
if gm.Status != groups.AllStatus {
queries = append(queries, "g.status = :status")
+38
View File
@@ -823,6 +823,44 @@ func TestRetrieveAll(t *testing.T) {
},
err: errors.ErrMalformedEntity,
},
{
desc: "retrieve groups with id",
page: groups.Page{
PageMeta: groups.PageMeta{
Offset: 0,
Limit: 10,
ID: items[0].ID,
},
},
response: groups.Page{
PageMeta: groups.PageMeta{
Total: 1,
Offset: 0,
Limit: 10,
},
Groups: []groups.Group{items[0]},
},
err: nil,
},
{
desc: "retrieve groups with wrong id",
page: groups.Page{
PageMeta: groups.PageMeta{
Offset: 0,
Limit: 10,
ID: "wrong",
},
},
response: groups.Page{
PageMeta: groups.PageMeta{
Total: 0,
Offset: 0,
Limit: 10,
},
Groups: []groups.Group(nil),
},
err: nil,
},
}
for _, tc := range cases {
+54 -62
View File
@@ -15,7 +15,6 @@ import (
"github.com/absmach/supermq/channels"
chapi "github.com/absmach/supermq/channels/api/http"
chmocks "github.com/absmach/supermq/channels/mocks"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/internal/testsutil"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
@@ -54,14 +53,14 @@ func TestCreateChannel(t *testing.T) {
createChannelReq := channels.Channel{
Name: channel.Name,
Metadata: clients.Metadata{"role": "client"},
Status: clients.EnabledStatus,
Metadata: channels.Metadata{"role": "client"},
Status: channels.EnabledStatus,
}
channelReq := sdk.Channel{
Name: channel.Name,
Metadata: validMetadata,
Status: clients.EnabledStatus.String(),
Status: channels.EnabledStatus.String(),
}
parentID := testsutil.GenerateUUID(&testing.T{})
@@ -69,7 +68,7 @@ func TestCreateChannel(t *testing.T) {
pChannel.ParentGroup = parentID
iChannel := convertChannel(channel)
iChannel.Metadata = clients.Metadata{
iChannel.Metadata = channels.Metadata{
"test": make(chan int),
}
@@ -135,14 +134,14 @@ func TestCreateChannel(t *testing.T) {
channelReq: sdk.Channel{
Name: channel.Name,
ParentGroup: parentID,
Status: clients.EnabledStatus.String(),
Status: channels.EnabledStatus.String(),
},
domainID: domainID,
token: validToken,
createChannelReq: channels.Channel{
Name: channel.Name,
ParentGroup: parentID,
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
},
svcRes: []channels.Channel{convertChannel(pChannel)},
svcErr: nil,
@@ -154,14 +153,14 @@ func TestCreateChannel(t *testing.T) {
channelReq: sdk.Channel{
Name: channel.Name,
ParentGroup: wrongID,
Status: clients.EnabledStatus.String(),
Status: channels.EnabledStatus.String(),
},
domainID: domainID,
token: validToken,
createChannelReq: channels.Channel{
Name: channel.Name,
ParentGroup: wrongID,
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
},
svcRes: []channels.Channel{},
svcErr: svcerr.ErrCreateEntity,
@@ -177,7 +176,7 @@ func TestCreateChannel(t *testing.T) {
Metadata: validMetadata,
CreatedAt: channel.CreatedAt,
UpdatedAt: channel.UpdatedAt,
Status: clients.EnabledStatus.String(),
Status: channels.EnabledStatus.String(),
},
domainID: domainID,
token: validToken,
@@ -185,10 +184,10 @@ func TestCreateChannel(t *testing.T) {
ID: channel.ID,
ParentGroup: parentID,
Name: channel.Name,
Metadata: clients.Metadata{"role": "client"},
Metadata: channels.Metadata{"role": "client"},
CreatedAt: channel.CreatedAt,
UpdatedAt: channel.UpdatedAt,
Status: clients.EnabledStatus,
Status: channels.EnabledStatus,
},
svcRes: []channels.Channel{convertChannel(pChannel)},
svcErr: nil,
@@ -238,11 +237,7 @@ func TestCreateChannels(t *testing.T) {
mgsdk := sdk.NewSDK(conf)
for i := 0; i < 3; i++ {
gr := sdk.Channel{
ID: generateUUID(t),
Name: fmt.Sprintf("channel_%d", i),
Metadata: sdk.Metadata{"name": fmt.Sprintf("client_%d", i)},
}
gr := generateTestChannel(t)
chs = append(chs, gr)
}
@@ -320,7 +315,7 @@ func TestCreateChannels(t *testing.T) {
svcRes: []channels.Channel{
{
ID: generateUUID(t),
Metadata: clients.Metadata{
Metadata: channels.Metadata{
"test": make(chan int),
},
},
@@ -358,11 +353,7 @@ func TestListChannels(t *testing.T) {
mgsdk := sdk.NewSDK(conf)
for i := 10; i < 100; i++ {
gr := sdk.Channel{
ID: generateUUID(t),
Name: fmt.Sprintf("channel_%d", i),
Metadata: sdk.Metadata{"name": fmt.Sprintf("client_%d", i)},
}
gr := generateTestChannel(t)
chs = append(chs, gr)
}
@@ -371,15 +362,15 @@ func TestListChannels(t *testing.T) {
domainID string
token string
session smqauthn.Session
status clients.Status
status channels.Status
total uint64
offset uint64
limit uint64
level int
name string
metadata sdk.Metadata
channelsPageMeta channels.PageMetadata
svcRes channels.Page
channelsPageMeta channels.Page
svcRes channels.ChannelsPage
svcErr error
authenticateRes smqauthn.Session
authenticateErr error
@@ -393,15 +384,15 @@ func TestListChannels(t *testing.T) {
limit: limit,
offset: offset,
total: total,
channelsPageMeta: channels.PageMetadata{
channelsPageMeta: channels.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: limit,
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
svcRes: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(len(chs[offset:limit])),
},
Channels: convertChannels(chs[offset:limit]),
@@ -420,14 +411,14 @@ func TestListChannels(t *testing.T) {
domainID: domainID,
offset: offset,
limit: limit,
channelsPageMeta: channels.PageMetadata{
channelsPageMeta: channels.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: limit,
},
svcRes: channels.Page{},
svcRes: channels.ChannelsPage{},
authenticateErr: svcerr.ErrAuthentication,
response: sdk.ChannelsPage{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
@@ -438,12 +429,12 @@ func TestListChannels(t *testing.T) {
domainID: validID,
offset: offset,
limit: limit,
channelsPageMeta: channels.PageMetadata{
channelsPageMeta: channels.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
},
svcRes: channels.Page{},
svcRes: channels.ChannelsPage{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
@@ -454,15 +445,15 @@ func TestListChannels(t *testing.T) {
domainID: domainID,
offset: offset,
limit: 0,
channelsPageMeta: channels.PageMetadata{
channelsPageMeta: channels.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: 10,
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
svcRes: channels.ChannelsPage{
Page: channels.Page{
Total: uint64(len(chs[offset:])),
},
Channels: convertChannels(chs[offset:limit]),
@@ -482,12 +473,12 @@ func TestListChannels(t *testing.T) {
domainID: domainID,
offset: offset,
limit: 110,
channelsPageMeta: channels.PageMetadata{
channelsPageMeta: channels.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
},
svcRes: channels.Page{},
svcRes: channels.ChannelsPage{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
@@ -499,15 +490,15 @@ func TestListChannels(t *testing.T) {
offset: 0,
limit: 1,
level: 1,
channelsPageMeta: channels.PageMetadata{
channelsPageMeta: channels.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: 1,
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
svcRes: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: convertChannels(chs[0:1]),
@@ -528,16 +519,16 @@ func TestListChannels(t *testing.T) {
offset: 0,
limit: 10,
metadata: sdk.Metadata{"name": "client_89"},
channelsPageMeta: channels.PageMetadata{
channelsPageMeta: channels.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: offset,
Limit: 10,
Metadata: clients.Metadata{"name": "client_89"},
Metadata: channels.Metadata{"name": "client_89"},
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
svcRes: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: convertChannels([]sdk.Channel{chs[89]}),
@@ -560,12 +551,12 @@ func TestListChannels(t *testing.T) {
metadata: sdk.Metadata{
"test": make(chan int),
},
channelsPageMeta: channels.PageMetadata{
channelsPageMeta: channels.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
},
svcRes: channels.Page{},
svcRes: channels.ChannelsPage{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
@@ -576,20 +567,20 @@ func TestListChannels(t *testing.T) {
domainID: domainID,
offset: 0,
limit: 10,
channelsPageMeta: channels.PageMetadata{
channelsPageMeta: channels.Page{
Actions: []string{},
Order: "updated_at",
Dir: "asc",
Offset: 0,
Limit: 10,
},
svcRes: channels.Page{
PageMetadata: channels.PageMetadata{
svcRes: channels.ChannelsPage{
Page: channels.Page{
Total: 1,
},
Channels: []channels.Channel{{
ID: generateUUID(t),
Metadata: clients.Metadata{
Metadata: channels.Metadata{
"test": make(chan int),
},
}},
@@ -705,7 +696,7 @@ func TestViewChannel(t *testing.T) {
channelID: channelRes.ID,
svcRes: channels.Channel{
ID: generateUUID(t),
Metadata: clients.Metadata{
Metadata: channels.Metadata{
"test": make(chan int),
},
},
@@ -745,7 +736,7 @@ func TestUpdateChannel(t *testing.T) {
mgsdk := sdk.NewSDK(conf)
mChannel := convertChannel(channel)
mChannel.Metadata = clients.Metadata{
mChannel.Metadata = channels.Metadata{
"field": "value2",
}
msdkChannel := channel
@@ -760,7 +751,7 @@ func TestUpdateChannel(t *testing.T) {
aChannel := convertChannel(channel)
aChannel.Name = newName
aChannel.Metadata = clients.Metadata{"field": "value2"}
aChannel.Metadata = channels.Metadata{"field": "value2"}
asdkChannel := channel
asdkChannel.Name = newName
asdkChannel.Metadata = sdk.Metadata{"field": "value2"}
@@ -807,7 +798,7 @@ func TestUpdateChannel(t *testing.T) {
},
updateChannelReq: channels.Channel{
ID: channel.ID,
Metadata: clients.Metadata{"field": "value2"},
Metadata: channels.Metadata{"field": "value2"},
},
svcRes: mChannel,
svcErr: nil,
@@ -827,7 +818,7 @@ func TestUpdateChannel(t *testing.T) {
ID: channel.ID,
Name: newName,
Metadata: clients.Metadata{"field": "value2"},
Metadata: channels.Metadata{"field": "value2"},
},
svcRes: aChannel,
svcErr: nil,
@@ -878,7 +869,7 @@ func TestUpdateChannel(t *testing.T) {
},
updateChannelReq: channels.Channel{
ID: wrongID,
Metadata: clients.Metadata{"field": "value2"},
Metadata: channels.Metadata{"field": "value2"},
},
svcRes: channels.Channel{},
svcErr: svcerr.ErrNotFound,
@@ -964,7 +955,7 @@ func TestUpdateChannel(t *testing.T) {
},
svcRes: channels.Channel{
ID: generateUUID(t),
Metadata: clients.Metadata{
Metadata: channels.Metadata{
"test": make(chan int),
},
},
@@ -1128,7 +1119,7 @@ func TestUpdateChannelTags(t *testing.T) {
svcRes: channels.Channel{
Name: updatedChannel.Name,
Tags: updatedChannel.Tags,
Metadata: clients.Metadata{
Metadata: channels.Metadata{
"test": make(chan int),
},
},
@@ -1235,7 +1226,7 @@ func TestEnableChannel(t *testing.T) {
channelID: channel.ID,
svcRes: channels.Channel{
ID: generateUUID(t),
Metadata: clients.Metadata{
Metadata: channels.Metadata{
"test": make(chan int),
},
},
@@ -1274,7 +1265,7 @@ func TestDisableChannel(t *testing.T) {
mgsdk := sdk.NewSDK(conf)
dChannel := channel
dChannel.Status = clients.DisabledStatus.String()
dChannel.Status = channels.DisabledStatus.String()
cases := []struct {
desc string
@@ -1345,7 +1336,7 @@ func TestDisableChannel(t *testing.T) {
channelID: channel.ID,
svcRes: channels.Channel{
ID: generateUUID(t),
Metadata: clients.Metadata{
Metadata: channels.Metadata{
"test": make(chan int),
},
},
@@ -2091,6 +2082,7 @@ func generateTestChannel(t *testing.T) sdk.Channel {
Metadata: sdk.Metadata{"role": "client"},
CreatedAt: createdAt,
UpdatedAt: updatedAt,
Status: channels.EnabledStatus.String(),
}
return ch
}
+4 -3
View File
@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/absmach/supermq/channels"
mgchannels "github.com/absmach/supermq/channels"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/domains"
@@ -198,9 +199,9 @@ func convertClient(c sdk.Client) clients.Client {
func convertChannel(g sdk.Channel) mgchannels.Channel {
if g.Status == "" {
g.Status = clients.EnabledStatus.String()
g.Status = channels.EnabledStatus.String()
}
status, err := clients.ToStatus(g.Status)
status, err := channels.ToStatus(g.Status)
if err != nil {
return mgchannels.Channel{}
}
@@ -210,7 +211,7 @@ func convertChannel(g sdk.Channel) mgchannels.Channel {
Tags: g.Tags,
ParentGroup: g.ParentGroup,
Domain: g.DomainID,
Metadata: clients.Metadata(g.Metadata),
Metadata: channels.Metadata(g.Metadata),
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
UpdatedBy: g.UpdatedBy,