SMQ-2856 - Add tags to groups (#2860)

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Steve Munene
2025-05-16 21:09:16 +03:00
committed by GitHub
parent 9a9e22fbce
commit 86a8f49e82
28 changed files with 1025 additions and 43 deletions
+65
View File
@@ -197,6 +197,44 @@ paths:
"500":
$ref: "#/components/responses/ServiceError"
/{domainID}/groups/{groupID}/tags:
patch:
operationId: updateGroupTags
summary: Updates group tags.
description: |
Update is performed by replacing the current resource data with values
provided in a request payload. Note that the group's ID will not be
affected.
tags:
- Groups
parameters:
- $ref: "auth.yaml#/components/parameters/DomainID"
- $ref: "#/components/parameters/GroupID"
security:
- bearerAuth: []
requestBody:
$ref: "#/components/requestBodies/GroupUpdateTagsReq"
responses:
"200":
$ref: "#/components/responses/GroupRes"
"400":
description: Failed due to malformed query parameters.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"404":
description: Group does not exist.
"409":
description: Failed due to using an existing identity.
"415":
description: Missing or invalid content type.
"422":
description: Database can't process request.
"500":
$ref: "#/components/responses/ServiceError"
/{domainID}/groups/{groupID}/enable:
post:
operationId: enableGroup
@@ -1038,6 +1076,12 @@ components:
type: string
example: long group description
description: Group description, free form text.
tags:
type: array
items:
type: string
example: ["tag1", "tag2"]
description: Group tags.
metadata:
type: object
example: { "role": "general" }
@@ -1219,6 +1263,19 @@ components:
- metadata
- description
GroupUpdateTags:
type: object
properties:
tags:
type: array
minItems: 0
items:
type: string
example: ["tag1", "tag2"]
description: Group tags.
required:
- tags
ParentGroupReqObj:
type: object
properties:
@@ -1490,6 +1547,14 @@ components:
schema:
$ref: "#/components/schemas/GroupUpdate"
GroupUpdateTagsReq:
description: JSON-formated document describing the tags of group to be updated.
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/GroupUpdateTags"
GroupParentReq:
description: JSON-formated document describing the parent group to be set to or removed from a group.
required: true
+22 -3
View File
@@ -11,6 +11,8 @@ import (
"github.com/spf13/cobra"
)
const tags = "tags"
var cmdGroups = []cobra.Command{
{
Use: "create <JSON_group> <domain_id> <user_auth_token>",
@@ -38,18 +40,35 @@ var cmdGroups = []cobra.Command{
},
},
{
Use: "update <JSON_group> <domain_id> <user_auth_token>",
Use: "update [<JSON_group> <domain_id> | tags <group_id> <tags> ] <user_auth_token>",
Short: "Update group",
Long: "Updates group\n" +
"Usage:\n" +
"\tsupermq-cli groups update '{\"id\":\"<group_id>\", \"name\":\"new group\", \"description\":\"new group description\", \"metadata\":{\"key\": \"value\"}}' $DOMAINID $USERTOKEN\n",
"\tsupermq-cli groups update '{\"id\":\"<group_id>\", \"name\":\"new group\", \"description\":\"new group description\", \"metadata\":{\"key\": \"value\"}}' $DOMAINID $USERTOKEN\n" +
"\tsupermq-cli groups update tags <group_id> '{\"tag1\":\"value1\", \"tag2\":\"value2\"}' $DOMAINID $USERTOKEN\n",
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 3 {
if len(args) != 3 && len(args) != 5 {
logUsageCmd(*cmd, cmd.Use)
return
}
var group smqsdk.Group
if args[0] == tags {
if err := json.Unmarshal([]byte(args[2]), &group.Tags); err != nil {
logErrorCmd(*cmd, err)
return
}
group.ID = args[1]
group, err := sdk.UpdateGroupTags(cmd.Context(), group, args[3], args[4])
if err != nil {
logErrorCmd(*cmd, err)
return
}
logJSONCmd(*cmd, group)
return
}
if err := json.Unmarshal([]byte(args[0]), &group); err != nil {
logErrorCmd(*cmd, err)
return
+53
View File
@@ -20,6 +20,11 @@ import (
"github.com/stretchr/testify/mock"
)
const (
tagUpdateType = "tags"
newTagsJson = "[\"tag1\", \"tag2\"]"
)
var group = smqsdk.Group{
ID: testsutil.GenerateUUID(&testing.T{}),
Name: "testgroup",
@@ -196,6 +201,8 @@ func TestUpdategroupCmd(t *testing.T) {
groupCmd := cli.NewGroupsCmd()
rootCmd := setFlags(groupCmd)
newTagString := []string{"tag1", "tag2"}
newGroupJson := fmt.Sprintf("{\"id\":\"%s\",\"name\" : \"newgroup\"}", group.ID)
cases := []struct {
desc string
@@ -250,11 +257,56 @@ func TestUpdategroupCmd(t *testing.T) {
errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.New("unexpected end of JSON input")),
logType: errLog,
},
{
desc: "update group tags successfully",
args: []string{
tagUpdateType,
group.ID,
newTagsJson,
domainID,
token,
},
group: smqsdk.Group{
Name: group.Name,
ID: group.ID,
DomainID: group.DomainID,
Status: group.Status,
Tags: newTagString,
},
logType: entityLog,
},
{
desc: "update group with invalid tags",
args: []string{
tagUpdateType,
group.ID,
"[\"tag1\", \"tag2\"",
domainID,
token,
},
logType: errLog,
sdkErr: errors.NewSDKError(errors.New("unexpected end of JSON input")),
errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.New("unexpected end of JSON input")),
},
{
desc: "update group tags with invalid group id",
args: []string{
tagUpdateType,
invalidID,
newTagsJson,
domainID,
token,
},
sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)),
logType: errLog,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
var ch smqsdk.Group
sdkCall := sdkMock.On("UpdateGroup", mock.Anything, mock.Anything, tc.args[1], tc.args[2]).Return(tc.group, tc.sdkErr)
sdkCall1 := sdkMock.On("UpdateGroupTags", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.group, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{updCmd}, tc.args...)...)
switch tc.logType {
@@ -268,6 +320,7 @@ func TestUpdategroupCmd(t *testing.T) {
assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out))
}
sdkCall.Unset()
sdkCall1.Unset()
})
}
}
+15
View File
@@ -68,6 +68,21 @@ func DecodeGroupUpdate(_ context.Context, r *http.Request) (interface{}, error)
return req, nil
}
func decodeUpdateGroupTags(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := updateGroupTagsReq{
id: chi.URLParam(r, "groupID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
}
return req, nil
}
func DecodeGroupRequest(_ context.Context, r *http.Request) (interface{}, error) {
roles, err := apiutil.ReadBoolQuery(r, api.RolesKey, false)
if err != nil {
+140
View File
@@ -474,6 +474,146 @@ func TestUpdateGroupEndpoint(t *testing.T) {
}
}
func TestUpdateGroupTagsEndpoint(t *testing.T) {
gs, svc, authn := newGroupsServer()
defer gs.Close()
newTag := "newtag"
cases := []struct {
desc string
token string
id string
domainID string
data string
contentType string
session smqauthn.Session
svcResp groups.Group
svcErr error
resp groups.Group
status int
authnErr error
err error
}{
{
desc: "update group tags successfully",
token: validToken,
domainID: validID,
id: validID,
data: fmt.Sprintf(`{"tags":["%s"]}`, newTag),
contentType: contentType,
svcResp: validGroupResp,
status: http.StatusOK,
err: nil,
},
{
desc: "update group tags with invalid token",
token: invalidToken,
session: smqauthn.Session{},
domainID: validID,
id: validID,
data: fmt.Sprintf(`{"tags":["%s"]}`, newTag),
contentType: contentType,
authnErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "update group tags with empty token",
token: "",
session: smqauthn.Session{},
domainID: validID,
id: validID,
data: fmt.Sprintf(`{"tags":["%s"]}`, newTag),
contentType: contentType,
status: http.StatusUnauthorized,
err: apiutil.ErrBearerToken,
},
{
desc: "update group tags with empty domainID",
token: validToken,
id: validID,
data: fmt.Sprintf(`{"tags":["%s"]}`, newTag),
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrMissingDomainID,
},
{
desc: "update group tags with invalid content type",
token: validToken,
id: validID,
domainID: validID,
data: fmt.Sprintf(`{"tags":["%s"]}`, newTag),
contentType: "application/xml",
svcResp: validGroupResp,
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update group tags with service error",
token: validToken,
id: validID,
domainID: validID,
data: fmt.Sprintf(`{"tags":["%s"]}`, newTag),
contentType: contentType,
svcResp: groups.Group{},
svcErr: svcerr.ErrAuthorization,
status: http.StatusForbidden,
err: svcerr.ErrAuthorization,
},
{
desc: "update group with malformed request",
token: validToken,
id: validID,
domainID: validID,
contentType: contentType,
data: fmt.Sprintf(`{"tags":["%s"}`, newTag),
status: http.StatusBadRequest,
err: errors.ErrMalformedEntity,
},
{
desc: "update group with empty id",
token: validToken,
id: "",
domainID: validID,
contentType: contentType,
data: fmt.Sprintf(`{"tags":["%s"]}`, newTag),
status: http.StatusBadRequest,
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
req := testRequest{
client: gs.Client(),
method: http.MethodPatch,
url: fmt.Sprintf("%s/%s/groups/%s/tags", gs.URL, tc.domainID, tc.id),
contentType: tc.contentType,
token: tc.token,
body: strings.NewReader(tc.data),
}
if tc.token == validToken {
tc.session = smqauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID}
}
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
svcCall := svc.On("UpdateGroupTags", mock.Anything, tc.session, groups.Group{ID: tc.id, Tags: []string{newTag}}).Return(tc.svcResp, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
var errRes respBody
err = json.NewDecoder(res.Body).Decode(&errRes)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if errRes.Err != "" || errRes.Message != "" {
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authCall.Unset()
})
}
}
func TestEnableGroupEndpoint(t *testing.T) {
gs, svc, authn := newGroupsServer()
defer gs.Close()
+25
View File
@@ -85,6 +85,31 @@ func UpdateGroupEndpoint(svc groups.Service) endpoint.Endpoint {
}
}
func updateGroupTagsEndpoint(svc groups.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(updateGroupTagsReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
g := groups.Group{
ID: req.id,
Tags: req.Tags,
}
g, err := svc.UpdateGroupTags(ctx, session, g)
if err != nil {
return nil, err
}
return updateGroupRes{Group: g}, nil
}
}
func EnableGroupEndpoint(svc groups.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(changeGroupStatusReq)
+13
View File
@@ -38,6 +38,19 @@ func (req updateGroupReq) validate() error {
return nil
}
type updateGroupTagsReq struct {
id string
Tags []string `json:"tags,omitempty"`
}
func (req updateGroupTagsReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type listGroupsReq struct {
groups.PageMeta
userID string
+7
View File
@@ -59,6 +59,13 @@ func MakeHandler(svc groups.Service, authn authn.Authentication, mux *chi.Mux, l
opts...,
), "update_group").ServeHTTP)
r.Patch("/tags", otelhttp.NewHandler(kithttp.NewServer(
updateGroupTagsEndpoint(svc),
decodeUpdateGroupTags,
api.EncodeResponse,
opts...,
), "update_group_tags").ServeHTTP)
r.Delete("/", otelhttp.NewHandler(kithttp.NewServer(
DeleteGroupEndpoint(svc),
DecodeGroupRequest,
+4 -1
View File
@@ -16,6 +16,7 @@ const (
groupPrefix = "group."
groupCreate = groupPrefix + "create"
groupUpdate = groupPrefix + "update"
groupUpdateTags = groupPrefix + "update_tags"
groupEnable = groupPrefix + "enable"
groupDisable = groupPrefix + "disable"
groupView = groupPrefix + "view"
@@ -93,16 +94,18 @@ func (cge createGroupEvent) Encode() (map[string]interface{}, error) {
type updateGroupEvent struct {
groups.Group
authn.Session
operation string
requestID string
}
func (uge updateGroupEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": groupUpdate,
"operation": uge.operation,
"updated_at": uge.UpdatedAt,
"updated_by": uge.UpdatedBy,
"domain": uge.DomainID,
"user_id": uge.UserID,
"tags": uge.Tags,
"token_type": uge.Type.String(),
"super_admin": uge.SuperAdmin,
"request_id": uge.requestID,
+24 -3
View File
@@ -19,6 +19,7 @@ const (
supermqPrefix = "supermq."
createStream = supermqPrefix + groupCreate
updateStream = supermqPrefix + groupUpdate
updateTagsStream = supermqPrefix + groupUpdateTags
enableStream = supermqPrefix + groupEnable
disableStream = supermqPrefix + groupDisable
viewStream = supermqPrefix + groupView
@@ -86,9 +87,10 @@ func (es eventStore) UpdateGroup(ctx context.Context, session authn.Session, gro
}
event := updateGroupEvent{
group,
session,
middleware.GetReqID(ctx),
Group: group,
Session: session,
operation: groupUpdate,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, updateStream, event); err != nil {
@@ -98,6 +100,25 @@ func (es eventStore) UpdateGroup(ctx context.Context, session authn.Session, gro
return group, nil
}
func (es *eventStore) UpdateGroupTags(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) {
g, err := es.svc.UpdateGroupTags(ctx, session, g)
if err != nil {
return g, err
}
event := updateGroupEvent{
Group: g,
Session: session,
operation: groupUpdateTags,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, updateTagsStream, event); err != nil {
return g, err
}
return g, nil
}
func (es eventStore) ViewGroup(ctx context.Context, session authn.Session, id string, withRoles bool) (groups.Group, error) {
group, err := es.svc.ViewGroup(ctx, session, id, withRoles)
if err != nil {
+7
View File
@@ -30,6 +30,7 @@ type Group struct {
Parent string `json:"parent_id,omitempty"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Tags []string `json:"tags,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Level int `json:"level,omitempty"`
Path string `json:"path,omitempty"`
@@ -91,6 +92,9 @@ type Repository interface {
// Update a group.
Update(ctx context.Context, g Group) (Group, error)
// Update a group's tags.
UpdateTags(ctx context.Context, g Group) (Group, error)
// RetrieveByID retrieves group by its id.
RetrieveByID(ctx context.Context, id string) (Group, error)
@@ -140,6 +144,9 @@ type Service interface {
// UpdateGroup updates the group identified by the provided ID.
UpdateGroup(ctx context.Context, session authn.Session, g Group) (Group, error)
// UpdateGroupTags updates the groups's tags.
UpdateGroupTags(ctx context.Context, session authn.Session, group Group) (Group, error)
// ViewGroup retrieves data about the group identified by ID.
ViewGroup(ctx context.Context, session authn.Session, id string, withRoles bool) (Group, error)
+27
View File
@@ -22,6 +22,7 @@ import (
var (
errView = errors.New("not authorized to view group")
errUpdate = errors.New("not authorized to update group")
errUpdateTags = errors.New("not authorized to update group tags")
errEnable = errors.New("not authorized to enable group")
errDisable = errors.New("not authorized to disable group")
errDelete = errors.New("not authorized to delete group")
@@ -150,6 +151,32 @@ func (am *authorizationMiddleware) UpdateGroup(ctx context.Context, session auth
return am.svc.UpdateGroup(ctx, session, g)
}
func (am *authorizationMiddleware) UpdateGroupTags(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: group.ID,
}); err != nil {
return groups.Group{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
}
if err := am.authorize(ctx, groups.OpUpdateGroupTags, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.GroupType,
Object: group.ID,
}); err != nil {
return groups.Group{}, errors.Wrap(errUpdateTags, err)
}
return am.svc.UpdateGroupTags(ctx, session, group)
}
func (am *authorizationMiddleware) ViewGroup(ctx context.Context, session authn.Session, id string, withRoles bool) (groups.Group, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
+22
View File
@@ -75,6 +75,28 @@ func (lm *loggingMiddleware) UpdateGroup(ctx context.Context, session authn.Sess
return lm.svc.UpdateGroup(ctx, session, group)
}
func (lm *loggingMiddleware) UpdateGroupTags(ctx context.Context, session authn.Session, group groups.Group) (g groups.Group, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.Group("group",
slog.String("id", g.ID),
slog.String("name", g.Name),
slog.Any("tags", g.Tags),
),
}
if err != nil {
args := append(args, slog.String("error", err.Error()))
lm.logger.Warn("Update group tags failed", args...)
return
}
lm.logger.Info("Update group tags completed successfully", args...)
}(time.Now())
return lm.svc.UpdateGroupTags(ctx, session, group)
}
// ViewGroup logs the view_group request. It logs the group name, id and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) ViewGroup(ctx context.Context, session authn.Session, id string, withRoles bool) (g groups.Group, err error) {
+9
View File
@@ -52,6 +52,15 @@ func (ms *metricsMiddleware) UpdateGroup(ctx context.Context, session authn.Sess
return ms.svc.UpdateGroup(ctx, session, group)
}
// UpdateGroupTags instruments UpdateGroupTags method with metrics.
func (ms *metricsMiddleware) UpdateGroupTags(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error) {
defer func(begin time.Time) {
ms.counter.With("method", "update_group_tags").Add(1)
ms.latency.With("method", "update_group_tags").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.UpdateGroupTags(ctx, session, group)
}
// ViewGroup instruments ViewGroup method with metrics.
func (ms *metricsMiddleware) ViewGroup(ctx context.Context, session authn.Session, id string, withRoles bool) (g groups.Group, err error) {
defer func(begin time.Time) {
+55
View File
@@ -2009,3 +2009,58 @@ func (_c *Repository_UpdateRole_Call) RunAndReturn(run func(ctx context.Context,
_c.Call.Return(run)
return _c
}
// UpdateTags provides a mock function for the type Repository
func (_mock *Repository) UpdateTags(ctx context.Context, g groups.Group) (groups.Group, error) {
ret := _mock.Called(ctx, g)
if len(ret) == 0 {
panic("no return value specified for UpdateTags")
}
var r0 groups.Group
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, groups.Group) (groups.Group, error)); ok {
return returnFunc(ctx, g)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, groups.Group) groups.Group); ok {
r0 = returnFunc(ctx, g)
} else {
r0 = ret.Get(0).(groups.Group)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, groups.Group) error); ok {
r1 = returnFunc(ctx, g)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_UpdateTags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateTags'
type Repository_UpdateTags_Call struct {
*mock.Call
}
// UpdateTags is a helper method to define mock.On call
// - ctx
// - g
func (_e *Repository_Expecter) UpdateTags(ctx interface{}, g interface{}) *Repository_UpdateTags_Call {
return &Repository_UpdateTags_Call{Call: _e.mock.On("UpdateTags", ctx, g)}
}
func (_c *Repository_UpdateTags_Call) Run(run func(ctx context.Context, g groups.Group)) *Repository_UpdateTags_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(groups.Group))
})
return _c
}
func (_c *Repository_UpdateTags_Call) Return(group groups.Group, err error) *Repository_UpdateTags_Call {
_c.Call.Return(group, err)
return _c
}
func (_c *Repository_UpdateTags_Call) RunAndReturn(run func(ctx context.Context, g groups.Group) (groups.Group, error)) *Repository_UpdateTags_Call {
_c.Call.Return(run)
return _c
}
+56
View File
@@ -1768,6 +1768,62 @@ func (_c *Service_UpdateGroup_Call) RunAndReturn(run func(ctx context.Context, s
return _c
}
// UpdateGroupTags provides a mock function for the type Service
func (_mock *Service) UpdateGroupTags(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error) {
ret := _mock.Called(ctx, session, group)
if len(ret) == 0 {
panic("no return value specified for UpdateGroupTags")
}
var r0 groups.Group
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, groups.Group) (groups.Group, error)); ok {
return returnFunc(ctx, session, group)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, groups.Group) groups.Group); ok {
r0 = returnFunc(ctx, session, group)
} else {
r0 = ret.Get(0).(groups.Group)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, groups.Group) error); ok {
r1 = returnFunc(ctx, session, group)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Service_UpdateGroupTags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateGroupTags'
type Service_UpdateGroupTags_Call struct {
*mock.Call
}
// UpdateGroupTags is a helper method to define mock.On call
// - ctx
// - session
// - group
func (_e *Service_Expecter) UpdateGroupTags(ctx interface{}, session interface{}, group interface{}) *Service_UpdateGroupTags_Call {
return &Service_UpdateGroupTags_Call{Call: _e.mock.On("UpdateGroupTags", ctx, session, group)}
}
func (_c *Service_UpdateGroupTags_Call) Run(run func(ctx context.Context, session authn.Session, group groups.Group)) *Service_UpdateGroupTags_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(authn.Session), args[2].(groups.Group))
})
return _c
}
func (_c *Service_UpdateGroupTags_Call) Return(group1 groups.Group, err error) *Service_UpdateGroupTags_Call {
_c.Call.Return(group1, err)
return _c
}
func (_c *Service_UpdateGroupTags_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error)) *Service_UpdateGroupTags_Call {
_c.Call.Return(run)
return _c
}
// UpdateRoleName provides a mock function for the type Service
func (_mock *Service) UpdateRoleName(ctx context.Context, session authn.Session, entityID string, roleID string, newRoleName string) (roles.Role, error) {
ret := _mock.Called(ctx, session, entityID, roleID, newRoleName)
+86 -35
View File
@@ -18,6 +18,7 @@ import (
"github.com/absmach/supermq/pkg/postgres"
"github.com/absmach/supermq/pkg/roles"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
"github.com/jackc/pgtype"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)
@@ -95,7 +96,7 @@ func (repo groupRepository) Update(ctx context.Context, g groups.Group) (groups.
g.Status = groups.EnabledStatus
q := fmt.Sprintf(`UPDATE groups SET %s updated_at = :updated_at, updated_by = :updated_by
WHERE id = :id AND status = :status
RETURNING id, name, description, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, updated_at, updated_by, status`, upq)
RETURNING id, name, tags, description, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, updated_at, updated_by, status`, upq)
dbu, err := toDBGroup(g)
if err != nil {
@@ -118,9 +119,38 @@ func (repo groupRepository) Update(ctx context.Context, g groups.Group) (groups.
return toGroup(dbu)
}
func (repo groupRepository) UpdateTags(ctx context.Context, group groups.Group) (groups.Group, error) {
q := `UPDATE groups SET tags = :tags, updated_at = :updated_at, updated_by = :updated_by
WHERE id = :id AND status = :status
RETURNING id, name, tags, metadata, COALESCE(domain_id, '') AS domain_id, COALESCE(parent_id, '') AS parent_id, status, created_at, updated_at, updated_by`
group.Status = groups.EnabledStatus
dbg, err := toDBGroup(group)
if err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
if err != nil {
return groups.Group{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
dbg = dbGroup{}
if row.Next() {
if err := row.StructScan(&dbg); err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return toGroup(dbg)
}
return groups.Group{}, repoerr.ErrNotFound
}
func (repo groupRepository) ChangeStatus(ctx context.Context, group groups.Group) (groups.Group, error) {
qc := `UPDATE groups SET status = :status, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id
RETURNING id, name, description, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, updated_at, updated_by, status`
RETURNING id, name, tags, description, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, updated_at, updated_by, status`
dbg, err := toDBGroup(group)
if err != nil {
@@ -143,7 +173,7 @@ func (repo groupRepository) ChangeStatus(ctx context.Context, group groups.Group
}
func (repo groupRepository) RetrieveByID(ctx context.Context, id string) (groups.Group, error) {
q := `SELECT id, name, domain_id, COALESCE(parent_id, '') AS parent_id, description, metadata, created_at, updated_at, updated_by, status, path FROM groups
q := `SELECT id, name, tags, domain_id, COALESCE(parent_id, '') AS parent_id, description, metadata, created_at, updated_at, updated_by, status, path FROM groups
WHERE id = :id`
dbg := dbGroup{
@@ -289,6 +319,7 @@ func (repo groupRepository) RetrieveByIDWithRoles(ctx context.Context, id, membe
g.parent_id,
g.domain_id,
g.name,
g.tags,
g.description,
g.path,
g.metadata,
@@ -334,6 +365,7 @@ func (repo groupRepository) RetrieveByIDAndUser(ctx context.Context, domainID, u
g.name,
g.domain_id,
COALESCE(g.parent_id, '') AS parent_id,
g.tags,
g.description,
g.metadata,
g.created_at,
@@ -378,7 +410,7 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
var q string
query := buildQuery(pm)
q = fmt.Sprintf(`SELECT DISTINCT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.description,
q = fmt.Sprintf(`SELECT DISTINCT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s ORDER BY g.created_at LIMIT :limit OFFSET :offset;`, query)
dbPageMeta, err := toDBGroupPageMeta(pm)
@@ -398,7 +430,7 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
cq := fmt.Sprintf(` SELECT COUNT(*) AS total_count
FROM (
SELECT DISTINCT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.description,
SELECT DISTINCT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s
) AS subquery;
`, query)
@@ -421,7 +453,7 @@ func (repo groupRepository) RetrieveByIDs(ctx context.Context, pm groups.PageMet
}
query := buildQuery(pm, ids...)
q = fmt.Sprintf(`SELECT DISTINCT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.description,
q = fmt.Sprintf(`SELECT DISTINCT g.id, g.domain_id, tags, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s ORDER BY g.created_at LIMIT :limit OFFSET :offset;`, query)
dbPageMeta, err := toDBGroupPageMeta(pm)
@@ -441,7 +473,7 @@ func (repo groupRepository) RetrieveByIDs(ctx context.Context, pm groups.PageMet
cq := fmt.Sprintf(` SELECT COUNT(*) AS total_count
FROM (
SELECT DISTINCT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.description,
SELECT DISTINCT g.id, g.domain_id, COALESCE(g.parent_id, '') AS parent_id, g.name, g.tags, g.description,
g.metadata, g.created_at, g.updated_at, g.updated_by, g.status FROM groups g %s
) AS subquery;
`, query)
@@ -469,6 +501,7 @@ func (repo groupRepository) RetrieveHierarchy(ctx context.Context, id string, hm
g.domain_id,
g.name,
g.description,
g.tags,
g.metadata,
g.created_at,
g.updated_at,
@@ -491,6 +524,7 @@ func (repo groupRepository) RetrieveHierarchy(ctx context.Context, id string, hm
COALESCE(g.parent_id, '') AS parent_id,
g.domain_id,
g.name,
g.tags,
g.description,
g.metadata,
g.created_at,
@@ -801,6 +835,7 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
g.domain_id,
COALESCE(g.parent_id, '') AS parent_id,
g.description,
g.tags,
g.metadata,
g.created_at,
g.updated_at,
@@ -848,6 +883,7 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
g.domain_id,
COALESCE(g.parent_id, '') AS parent_id,
g.description,
g.tags,
g.metadata,
g.created_at,
g.updated_at,
@@ -940,6 +976,7 @@ direct_indirect_groups as (
parent_id,
domain_id,
"name",
tags,
description,
metadata,
created_at,
@@ -963,6 +1000,7 @@ direct_indirect_groups as (
parent_id,
domain_id,
"name",
tags,
description,
metadata,
created_at,
@@ -987,6 +1025,7 @@ final_groups AS (
dig.parent_id,
dig.domain_id,
dig."name",
dig.tags,
dig.description,
dig.metadata,
dig.created_at,
@@ -1010,6 +1049,7 @@ final_groups AS (
dg.parent_id,
dg.domain_id,
dg."name",
dg.tags,
dg.description,
dg.metadata,
dg.created_at,
@@ -1093,28 +1133,29 @@ func buildQuery(gm groups.PageMeta, ids ...string) string {
}
type dbGroup struct {
ID string `db:"id"`
ParentID *string `db:"parent_id,omitempty"`
DomainID string `db:"domain_id,omitempty"`
Name string `db:"name"`
Description string `db:"description,omitempty"`
Level int `db:"level"`
Path string `db:"path,omitempty"`
Metadata []byte `db:"metadata,omitempty"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
UpdatedBy *string `db:"updated_by,omitempty"`
Status groups.Status `db:"status"`
RoleID string `db:"role_id"`
RoleName string `db:"role_name"`
Actions pq.StringArray `db:"actions"`
AccessType string `db:"access_type"`
AccessProviderId string `db:"access_provider_id"`
AccessProviderRoleId string `db:"access_provider_role_id"`
AccessProviderRoleName string `db:"access_provider_role_name"`
AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions"`
MemberID string `db:"member_id,omitempty"`
Roles json.RawMessage `db:"roles,omitempty"`
ID string `db:"id"`
ParentID *string `db:"parent_id,omitempty"`
DomainID string `db:"domain_id,omitempty"`
Name string `db:"name"`
Description string `db:"description,omitempty"`
Tags pgtype.TextArray `db:"tags,omitempty"`
Level int `db:"level"`
Path string `db:"path,omitempty"`
Metadata []byte `db:"metadata,omitempty"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
UpdatedBy *string `db:"updated_by,omitempty"`
Status groups.Status `db:"status"`
RoleID string `db:"role_id"`
RoleName string `db:"role_name"`
Actions pq.StringArray `db:"actions"`
AccessType string `db:"access_type"`
AccessProviderId string `db:"access_provider_id"`
AccessProviderRoleId string `db:"access_provider_role_id"`
AccessProviderRoleName string `db:"access_provider_role_name"`
AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions"`
MemberID string `db:"member_id,omitempty"`
Roles json.RawMessage `db:"roles,omitempty"`
}
func toDBGroup(g groups.Group) (dbGroup, error) {
@@ -1126,6 +1167,10 @@ func toDBGroup(g groups.Group) (dbGroup, error) {
}
data = b
}
var tags pgtype.TextArray
if err := tags.Set(g.Tags); err != nil {
return dbGroup{}, err
}
var parentID *string
if g.Parent != "" {
parentID = &g.Parent
@@ -1144,6 +1189,7 @@ func toDBGroup(g groups.Group) (dbGroup, error) {
ParentID: parentID,
DomainID: g.Domain,
Description: g.Description,
Tags: tags,
Metadata: data,
Path: g.Path,
CreatedAt: g.CreatedAt,
@@ -1160,6 +1206,10 @@ func toGroup(g dbGroup) (groups.Group, error) {
return groups.Group{}, errors.Wrap(repoerr.ErrMalformedEntity, err)
}
}
var tags []string
for _, e := range g.Tags.Elements {
tags = append(tags, e.String)
}
var parentID string
if g.ParentID != nil {
parentID = *g.ParentID
@@ -1186,6 +1236,7 @@ func toGroup(g dbGroup) (groups.Group, error) {
Parent: parentID,
Domain: g.DomainID,
Description: g.Description,
Tags: tags,
Metadata: metadata,
Level: g.Level,
Path: g.Path,
@@ -1276,12 +1327,12 @@ func (repo groupRepository) getInsertQuery(c context.Context, g groups.Group) (s
if len(strings.Split(path, ".")) > groups.MaxPathLength {
return "", fmt.Errorf("reached max nested depth")
}
return fmt.Sprintf(`INSERT INTO groups (name, description, id, domain_id, parent_id, metadata, created_at, status, path)
VALUES (:name, :description, :id, :domain_id, :parent_id, :metadata, :created_at, :status, '%s')
RETURNING id, name, description, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, path), nil
return fmt.Sprintf(`INSERT INTO groups (name, description, tags, id, domain_id, parent_id, metadata, created_at, status, path)
VALUES (:name, :description, :tags, :id, :domain_id, :parent_id, :metadata, :created_at, :status, '%s')
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, path), nil
default:
return `INSERT INTO groups (name, description, id, domain_id, metadata, created_at, status, path)
VALUES (:name, :description, :id, :domain_id, :metadata, :created_at, :status, :id)
RETURNING id, name, description, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, nil
return `INSERT INTO groups (name, description, tags, id, domain_id, metadata, created_at, status, path)
VALUES (:name, :description, :tags, :id, :domain_id, :metadata, :created_at, :status, :id)
RETURNING id, name, description, tags, domain_id, COALESCE(parent_id, '') AS parent_id, metadata, created_at, status, path, nlevel(path) as level;`, nil
}
}
+62
View File
@@ -29,6 +29,7 @@ var (
ID: testsutil.GenerateUUID(&testing.T{}),
Domain: testsutil.GenerateUUID(&testing.T{}),
Name: namegen.Generate(),
Tags: []string{"tag1", "tag2"},
Description: strings.Repeat("a", 64),
Metadata: map[string]interface{}{"key": "value"},
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
@@ -352,6 +353,67 @@ func TestUpdate(t *testing.T) {
}
}
func TestUpdateTags(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM groups")
require.Nil(t, err, fmt.Sprintf("clean groups unexpected error: %s", err))
})
repo := postgres.New(database)
_, err := repo.Save(context.Background(), validGroup)
require.Nil(t, err, fmt.Sprintf("save group unexpected error: %s", err))
cases := []struct {
desc string
group groups.Group
err error
}{
{
desc: "update group tags",
group: groups.Group{
ID: validGroup.ID,
Tags: []string{"tag3", "tag4"},
UpdatedAt: validTimestamp,
UpdatedBy: testsutil.GenerateUUID(t),
},
err: nil,
},
{
desc: "update group with invalid ID",
group: groups.Group{
ID: testsutil.GenerateUUID(t),
Tags: []string{"tag3", "tag4"},
UpdatedAt: validTimestamp,
UpdatedBy: testsutil.GenerateUUID(t),
},
err: repoerr.ErrNotFound,
},
{
desc: "update group with empty ID",
group: groups.Group{
Tags: []string{"tag3", "tag4"},
UpdatedAt: validTimestamp,
UpdatedBy: testsutil.GenerateUUID(t),
},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
group, err := repo.UpdateTags(context.Background(), tc.group)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.group.ID, group.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.ID, group.ID))
assert.Equal(t, tc.group.UpdatedAt, group.UpdatedAt, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedAt, group.UpdatedAt))
assert.Equal(t, tc.group.UpdatedBy, group.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.UpdatedBy, group.UpdatedBy))
assert.Equal(t, tc.group.Tags, group.Tags, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.group.Tags, group.Tags))
}
})
}
}
func TestChangeStatus(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM groups")
+9
View File
@@ -64,6 +64,15 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
`ALTER TABLE groups ADD CONSTRAINT groups_domain_id_name_key UNIQUE (domain_id, name)`,
},
},
{
Id: "groups_04",
Up: []string{
`ALTER TABLE groups ADD COLUMN tags TEXT[]`,
},
Down: []string{
`ALTER TABLE groups DROP COLUMN tags`,
},
},
},
}
+4
View File
@@ -13,6 +13,7 @@ import (
const (
OpViewGroup svcutil.Operation = iota
OpUpdateGroup
OpUpdateGroupTags
OpEnableGroup
OpDisableGroup
OpRetrieveGroupHierarchy
@@ -28,6 +29,7 @@ const (
var expectedOperations = []svcutil.Operation{
OpViewGroup,
OpUpdateGroup,
OpUpdateGroupTags,
OpEnableGroup,
OpDisableGroup,
OpRetrieveGroupHierarchy,
@@ -43,6 +45,7 @@ var expectedOperations = []svcutil.Operation{
var operationNames = []string{
"OpViewGroup",
"OpUpdateGroup",
"OpUpdateGroupTags",
"OpEnableGroup",
"OpDisableGroup",
"OpRetrieveGroupHierarchy",
@@ -107,6 +110,7 @@ func NewOperationPermissionMap() map[svcutil.Operation]svcutil.Permission {
opPerm := map[svcutil.Operation]svcutil.Permission{
OpViewGroup: readPermission,
OpUpdateGroup: updatePermission,
OpUpdateGroupTags: updatePermission,
OpEnableGroup: updatePermission,
OpDisableGroup: updatePermission,
OpRetrieveGroupHierarchy: readPermission,
+14
View File
@@ -158,6 +158,20 @@ func (svc service) UpdateGroup(ctx context.Context, session smqauthn.Session, g
return group, nil
}
func (svc service) UpdateGroupTags(ctx context.Context, session smqauthn.Session, g Group) (Group, error) {
group := Group{
ID: g.ID,
Tags: g.Tags,
UpdatedAt: time.Now(),
UpdatedBy: session.UserID,
}
group, err := svc.repo.UpdateTags(ctx, group)
if err != nil {
return Group{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
return group, nil
}
func (svc service) EnableGroup(ctx context.Context, session smqauthn.Session, id string) (Group, error) {
group := Group{
ID: id,
+47
View File
@@ -342,6 +342,53 @@ func TestUpdateGroup(t *testing.T) {
}
}
func TestUpdateGroupTags(t *testing.T) {
svc := newService(t)
cases := []struct {
desc string
updateReq groups.Group
repoResp groups.Group
repoErr error
err error
}{
{
desc: "update group tags successfully",
updateReq: groups.Group{
ID: testsutil.GenerateUUID(t),
Tags: []string{"tag1", "tag2"},
},
repoResp: groups.Group{
ID: testsutil.GenerateUUID(t),
Tags: []string{"tag1", "tag2"},
},
},
{
desc: "update group tags with repo error",
updateReq: groups.Group{
ID: testsutil.GenerateUUID(t),
Tags: []string{"tag1", "tag2"},
},
repoErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := repo.On("UpdateTags", context.Background(), mock.Anything).Return(tc.repoResp, tc.repoErr)
got, err := svc.UpdateGroupTags(context.Background(), validSession, tc.updateReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
if err == nil {
assert.Equal(t, tc.repoResp, got)
ok := repo.AssertCalled(t, "UpdateTags", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("UpdateTags was not called on %s", tc.desc))
}
repoCall.Unset()
})
}
}
func TestEnableGroup(t *testing.T) {
svc := newService(t)
+11
View File
@@ -89,6 +89,17 @@ func (tm *tracingMiddleware) UpdateGroup(ctx context.Context, session authn.Sess
return tm.svc.UpdateGroup(ctx, session, g)
}
// UpdateGroupTags traces the "UpdateGroupTags" operation of the wrapped groups.Service.
func (tm *tracingMiddleware) UpdateGroupTags(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_update_group_tags", trace.WithAttributes(
attribute.String("id", group.ID),
attribute.StringSlice("tags", group.Tags),
))
defer span.End()
return tm.svc.UpdateGroupTags(ctx, session, group)
}
// EnableGroup traces the "EnableGroup" operation of the wrapped groups.Service.
func (tm *tracingMiddleware) EnableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_enable_group", trace.WithAttributes(attribute.String("id", id)))
+25
View File
@@ -32,6 +32,7 @@ type Group struct {
ParentID string `json:"parent_id,omitempty"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Tags []string `json:"tags,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Level int `json:"level,omitempty"`
Path string `json:"path,omitempty"`
@@ -135,6 +136,30 @@ func (sdk mgSDK) UpdateGroup(ctx context.Context, g Group, domainID, token strin
return g, nil
}
func (sdk mgSDK) UpdateGroupTags(ctx context.Context, g Group, domainID, token string) (Group, errors.SDKError) {
if g.ID == "" {
return Group{}, errors.NewSDKError(apiutil.ErrMissingID)
}
url := fmt.Sprintf("%s/%s/%s/%s/tags", sdk.groupsURL, domainID, groupsEndpoint, g.ID)
data, err := json.Marshal(g)
if err != nil {
return Group{}, errors.NewSDKError(err)
}
_, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK)
if sdkerr != nil {
return Group{}, sdkerr
}
g = Group{}
if err := json.Unmarshal(body, &g); err != nil {
return Group{}, errors.NewSDKError(err)
}
return g, nil
}
func (sdk mgSDK) SetGroupParent(ctx context.Context, id, domainID, groupID, token string) errors.SDKError {
scpg := groupParentReq{ParentID: groupID}
data, err := json.Marshal(scpg)
+151
View File
@@ -850,6 +850,157 @@ func TestUpdateGroup(t *testing.T) {
}
}
func TestUpdateGroupTags(t *testing.T) {
ts, tsvc, auth := setupGroups()
defer ts.Close()
sdkGroup := generateTestGroup(t)
updatedGroup := sdkGroup
updatedGroup.Tags = []string{"newTag1", "newTag2"}
updateGroupReq := sdk.Group{
ID: sdkGroup.ID,
Tags: updatedGroup.Tags,
}
conf := sdk.Config{
GroupsURL: ts.URL,
}
mgsdk := sdk.NewSDK(conf)
cases := []struct {
desc string
domainID string
token string
session smqauthn.Session
updateGroupReq sdk.Group
svcReq groups.Group
svcRes groups.Group
svcErr error
authenticateErr error
response sdk.Group
err errors.SDKError
}{
{
desc: "update group tags successfully",
domainID: domainID,
token: validToken,
updateGroupReq: updateGroupReq,
svcReq: convertGroup(updateGroupReq),
svcRes: convertGroup(updatedGroup),
svcErr: nil,
response: updatedGroup,
err: nil,
},
{
desc: "update group tags with an invalid token",
domainID: domainID,
token: invalidToken,
updateGroupReq: updateGroupReq,
svcReq: convertGroup(updateGroupReq),
svcRes: groups.Group{},
authenticateErr: svcerr.ErrAuthorization,
response: sdk.Group{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
},
{
desc: "update group tags with empty token",
domainID: domainID,
token: "",
updateGroupReq: updateGroupReq,
svcReq: convertGroup(updateGroupReq),
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "update group tags with an invalid group id",
domainID: domainID,
token: validToken,
updateGroupReq: sdk.Group{
ID: wrongID,
Tags: updatedGroup.Tags,
},
svcReq: convertGroup(sdk.Group{
ID: wrongID,
Tags: updatedGroup.Tags,
}),
svcRes: groups.Group{},
svcErr: svcerr.ErrUpdateEntity,
response: sdk.Group{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrUpdateEntity, http.StatusUnprocessableEntity),
},
{
desc: "update group tags with empty group id",
domainID: domainID,
token: validToken,
updateGroupReq: sdk.Group{
ID: "",
Tags: updatedGroup.Tags,
},
svcReq: convertGroup(sdk.Group{
ID: "",
Tags: updatedGroup.Tags,
}),
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(apiutil.ErrMissingID),
},
{
desc: "update group tags with a request that can't be marshalled",
domainID: domainID,
token: validToken,
updateGroupReq: sdk.Group{
ID: "test",
Metadata: map[string]interface{}{
"test": make(chan int),
},
},
svcReq: groups.Group{},
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
},
{
desc: "update group tags with a response that can't be unmarshalled",
domainID: domainID,
token: validToken,
updateGroupReq: updateGroupReq,
svcReq: convertGroup(updateGroupReq),
svcRes: groups.Group{
Name: updatedGroup.Name,
Tags: updatedGroup.Tags,
Metadata: groups.Metadata{
"test": make(chan int),
},
},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("UpdateGroupTags", mock.Anything, tc.session, tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.UpdateGroupTags(context.Background(), tc.updateGroupReq, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "UpdateGroupTags", mock.Anything, tc.session, tc.svcReq)
assert.True(t, ok)
}
svcCall.Unset()
authCall.Unset()
})
}
}
func TestEnableGroup(t *testing.T) {
ts, gsvc, auth := setupGroups()
defer ts.Close()
+59
View File
@@ -6489,6 +6489,65 @@ func (_c *SDK_UpdateGroupRole_Call) RunAndReturn(run func(ctx context.Context, i
return _c
}
// UpdateGroupTags provides a mock function for the type SDK
func (_mock *SDK) UpdateGroupTags(ctx context.Context, group sdk.Group, domainID string, token string) (sdk.Group, errors.SDKError) {
ret := _mock.Called(ctx, group, domainID, token)
if len(ret) == 0 {
panic("no return value specified for UpdateGroupTags")
}
var r0 sdk.Group
var r1 errors.SDKError
if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Group, string, string) (sdk.Group, errors.SDKError)); ok {
return returnFunc(ctx, group, domainID, token)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Group, string, string) sdk.Group); ok {
r0 = returnFunc(ctx, group, domainID, token)
} else {
r0 = ret.Get(0).(sdk.Group)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.Group, string, string) errors.SDKError); ok {
r1 = returnFunc(ctx, group, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
}
}
return r0, r1
}
// SDK_UpdateGroupTags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateGroupTags'
type SDK_UpdateGroupTags_Call struct {
*mock.Call
}
// UpdateGroupTags is a helper method to define mock.On call
// - ctx
// - group
// - domainID
// - token
func (_e *SDK_Expecter) UpdateGroupTags(ctx interface{}, group interface{}, domainID interface{}, token interface{}) *SDK_UpdateGroupTags_Call {
return &SDK_UpdateGroupTags_Call{Call: _e.mock.On("UpdateGroupTags", ctx, group, domainID, token)}
}
func (_c *SDK_UpdateGroupTags_Call) Run(run func(ctx context.Context, group sdk.Group, domainID string, token string)) *SDK_UpdateGroupTags_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(sdk.Group), args[2].(string), args[3].(string))
})
return _c
}
func (_c *SDK_UpdateGroupTags_Call) Return(group1 sdk.Group, sDKError errors.SDKError) *SDK_UpdateGroupTags_Call {
_c.Call.Return(group1, sDKError)
return _c
}
func (_c *SDK_UpdateGroupTags_Call) RunAndReturn(run func(ctx context.Context, group sdk.Group, domainID string, token string) (sdk.Group, errors.SDKError)) *SDK_UpdateGroupTags_Call {
_c.Call.Return(run)
return _c
}
// UpdatePassword provides a mock function for the type SDK
func (_mock *SDK) UpdatePassword(ctx context.Context, oldPass string, newPass string, token string) (sdk.User, errors.SDKError) {
ret := _mock.Called(ctx, oldPass, newPass, token)
+11
View File
@@ -640,6 +640,17 @@ type SDK interface {
// fmt.Println(group)
UpdateGroup(ctx context.Context, group Group, domainID, token string) (Group, errors.SDKError)
// UpdateGroupTags updates tags for existing group.
//
// example:
// group := sdk.Group{
// ID: "groupID",
// Tags: []string{"tag1", "tag2"}
// }
// group, _ := sdk.UpdateGroupTags(group, "domainID", "token")
// fmt.Println(group)
UpdateGroupTags(ctx context.Context, group Group, domainID, token string) (Group, errors.SDKError)
// SetGroupParent sets the parent group of a group.
//
// example:
+2 -1
View File
@@ -113,6 +113,7 @@ func convertGroup(g sdk.Group) groups.Group {
Parent: g.ParentID,
Name: g.Name,
Description: g.Description,
Tags: g.Tags,
Metadata: groups.Metadata(g.Metadata),
Level: g.Level,
Path: g.Path,
@@ -133,7 +134,7 @@ func convertGroup(g sdk.Group) groups.Group {
}
func convertChildren(gs []*sdk.Group) []*groups.Group {
cg := []*groups.Group{}
var cg []*groups.Group
if len(gs) == 0 {
return cg