Files
magistrala/groups/service_test.go
T
Dušan Borovčanin 61d0427898 NOISSUE - Rename to Magistrala (#3427)
Signed-off-by: dusan <borovcanindusan1@gmail.com>
2026-04-06 15:23:42 +02:00

1325 lines
40 KiB
Go

// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package groups_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/0x6flab/namegenerator"
grpcChannelsV1 "github.com/absmach/magistrala/api/grpc/channels/v1"
grpcClientsV1 "github.com/absmach/magistrala/api/grpc/clients/v1"
apiutil "github.com/absmach/magistrala/api/http/util"
chmocks "github.com/absmach/magistrala/channels/mocks"
climocks "github.com/absmach/magistrala/clients/mocks"
"github.com/absmach/magistrala/groups"
"github.com/absmach/magistrala/groups/mocks"
"github.com/absmach/magistrala/internal/nullable"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/authn"
smqauthn "github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
policysvc "github.com/absmach/magistrala/pkg/policies"
policymocks "github.com/absmach/magistrala/pkg/policies/mocks"
"github.com/absmach/magistrala/pkg/roles"
"github.com/absmach/magistrala/pkg/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// Shared fixtures used across the service tests. They are built once at
// package initialisation; a throwaway &testing.T{} is handed to GenerateUUID
// because no real *testing.T is in scope at package level.
var (
idProvider = uuid.New()
namegen = namegenerator.NewGenerator()
description = namegen.Generate()
// desc is a set (Valid: true) nullable description shared by the fixtures.
desc = nullable.Value[string]{Valid: true, Value: description}
// validGroup is an enabled group with metadata but no parent, children or roles.
validGroup = groups.Group{
ID: testsutil.GenerateUUID(&testing.T{}),
Name: namegen.Generate(),
Description: desc,
Metadata: map[string]any{
"key": "value",
},
Status: groups.EnabledStatus,
}
// validGroupWithRoles mirrors validGroup but carries a single direct role,
// used where a view with roles is expected.
validGroupWithRoles = groups.Group{
ID: testsutil.GenerateUUID(&testing.T{}),
Name: namegen.Generate(),
Description: desc,
Metadata: map[string]any{
"key": "value",
},
Status: groups.EnabledStatus,
Roles: []roles.MemberRoleActions{
{
RoleID: "test-id",
RoleName: "test-name",
AccessType: "direct",
},
},
}
// parentGroupID and childGroupID tie childGroup and parentGroup together
// into a one-level hierarchy.
parentGroupID = testsutil.GenerateUUID(&testing.T{})
childGroupID = testsutil.GenerateUUID(&testing.T{})
// childGroup has parentGroupID set as its Parent.
childGroup = groups.Group{
ID: childGroupID,
Name: namegen.Generate(),
Description: desc,
Metadata: map[string]any{
"key": "value",
},
Status: groups.EnabledStatus,
Parent: parentGroupID,
}
children = []*groups.Group{&childGroup}
// parentGroup lists childGroup as its only child.
parentGroup = groups.Group{
ID: parentGroupID,
Name: namegen.Generate(),
Description: desc,
Metadata: map[string]any{
"key": "value",
},
Status: groups.EnabledStatus,
Children: children,
}
// validID is reused as the user, domain and domain-user ID of validSession.
validID = testsutil.GenerateUUID(&testing.T{})
validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID}
)
// Mocks backing the service under test. newService replaces all of them with
// fresh instances on every call, so each Test function that starts with
// newService(t) programs its own set of expectations.
var (
repo *mocks.Repository
policies *policymocks.Service
channels *chmocks.ChannelsServiceClient
clients *climocks.ClientsServiceClient
)
// newService builds a groups.Service wired to fresh mocks and stores those
// mocks in the package-level variables so the calling test can program them.
// The service is created with an empty action set and a single built-in admin
// role that grants that (empty) set.
//
// Construction failures abort the test immediately: carrying on with a nil
// service would only surface later as a confusing nil-pointer panic.
func newService(t *testing.T) groups.Service {
	t.Helper()

	repo = new(mocks.Repository)
	policies = new(policymocks.Service)
	channels = new(chmocks.ChannelsServiceClient)
	clients = new(climocks.ClientsServiceClient)

	availableActions := []roles.Action{}
	builtInRoles := map[roles.BuiltInRoleName][]roles.Action{
		groups.BuiltInRoleAdmin: availableActions,
	}

	svc, err := groups.NewService(repo, policies, idProvider, channels, clients, idProvider, availableActions, builtInRoles)
	if err != nil {
		t.Fatalf("unexpected error while creating service: %v", err)
	}

	return svc
}
// TestCreateGroup exercises CreateGroup with mocked Save, AddPolicies,
// DeletePolicies, AddRoles and Delete calls. Cases cover successful creation
// (with and without a parent), an invalid status, a failed save, and failures
// while adding policies or roles together with the rollback that follows.
func TestCreateGroup(t *testing.T) {
	svc := newService(t)

	cases := []struct {
		desc              string
		group             groups.Group
		saveResp          groups.Group
		saveErr           error
		deleteErr         error
		addPoliciesErr    error
		deletePoliciesErr error
		addRoleErr        error
		err               error
	}{
		{
			desc:  "create group successfully",
			group: validGroup,
			saveResp: groups.Group{
				ID:        testsutil.GenerateUUID(t),
				CreatedAt: time.Now(),
				Domain:    validID,
			},
			err: nil,
		},
		{
			desc: "create group with invalid status",
			group: groups.Group{
				Name:        namegen.Generate(),
				Description: desc,
				Status:      groups.Status(100),
			},
			err: svcerr.ErrInvalidStatus,
		},
		{
			desc: "create group successfully with parent",
			group: groups.Group{
				Name:        namegen.Generate(),
				Description: desc,
				Status:      groups.EnabledStatus,
				Parent:      testsutil.GenerateUUID(t),
			},
			saveResp: groups.Group{
				ID:        testsutil.GenerateUUID(t),
				CreatedAt: time.Now(),
				Domain:    testsutil.GenerateUUID(t),
				Parent:    testsutil.GenerateUUID(t),
			},
			err: nil,
		},
		{
			desc:     "create group with failed to save",
			group:    validGroup,
			saveResp: groups.Group{},
			saveErr:  errors.ErrMalformedEntity,
			err:      errors.ErrMalformedEntity,
		},
		{
			// Subtest names must not start with whitespace; t.Run would turn
			// it into a leading underscore.
			desc:  "create group with failed to add policies",
			group: validGroup,
			saveResp: groups.Group{
				ID:        testsutil.GenerateUUID(t),
				CreatedAt: time.Now(),
				Domain:    validID,
			},
			addPoliciesErr: svcerr.ErrAuthorization,
			err:            svcerr.ErrAddPolicies,
		},
		{
			desc:  "create group with failed to add policies and failed rollback",
			group: validGroup,
			saveResp: groups.Group{
				ID:        testsutil.GenerateUUID(t),
				CreatedAt: time.Now(),
				Domain:    validID,
			},
			addPoliciesErr: svcerr.ErrAuthorization,
			deleteErr:      svcerr.ErrRemoveEntity,
			err:            svcerr.ErrRemoveEntity,
		},
		{
			desc:  "create group with failed to add roles",
			group: validGroup,
			saveResp: groups.Group{
				ID:        testsutil.GenerateUUID(t),
				CreatedAt: time.Now(),
				Domain:    validID,
			},
			addRoleErr: svcerr.ErrCreateEntity,
			err:        svcerr.ErrAddPolicies,
		},
		{
			desc:  "create groups with failed to add roles and failed to delete policies",
			group: validGroup,
			saveResp: groups.Group{
				ID:        testsutil.GenerateUUID(t),
				CreatedAt: time.Now(),
				Domain:    validID,
			},
			addRoleErr:        svcerr.ErrCreateEntity,
			deletePoliciesErr: svcerr.ErrRemoveEntity,
			err:               svcerr.ErrAddPolicies,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			// Every dependency is stubbed in every case; a case selects its
			// failure point through the *Err fields above.
			repoCall := repo.On("Save", context.Background(), mock.Anything).Return(tc.saveResp, tc.saveErr)
			policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr)
			policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr)
			repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr)
			repoCall2 := repo.On("Delete", context.Background(), mock.Anything).Return(tc.deleteErr)

			got, _, err := svc.CreateGroup(context.Background(), validSession, tc.group)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v but got %v", tc.err, err))
			if err == nil {
				assert.NotEmpty(t, got.ID)
				assert.NotEmpty(t, got.CreatedAt)
				assert.NotEmpty(t, got.Domain)
				assert.WithinDuration(t, time.Now(), got.CreatedAt, 2*time.Second)
				ok := repo.AssertCalled(t, "Save", context.Background(), mock.Anything)
				assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
			}

			repoCall.Unset()
			policyCall.Unset()
			policyCall1.Unset()
			repoCall1.Unset()
			repoCall2.Unset()
		})
	}
}
// TestViewGroup checks that ViewGroup reads through RetrieveByID when roles
// are not requested and through RetrieveByIDWithRoles (keyed by the session's
// user ID) when they are. It also checks that a repository lookup failure is
// reported as svcerr.ErrViewEntity.
func TestViewGroup(t *testing.T) {
	svc := newService(t)

	cases := []struct {
		desc      string
		session   smqauthn.Session
		id        string
		withRoles bool
		repoResp  groups.Group
		repoErr   error
		err       error
	}{
		{
			desc:      "view group successfully",
			id:        validGroup.ID,
			session:   validSession,
			withRoles: false,
			repoResp:  validGroup,
		},
		{
			desc:      "view group successfully with roles",
			id:        validGroupWithRoles.ID,
			session:   validSession,
			withRoles: true,
			repoResp:  validGroupWithRoles,
		},
		{
			desc:      "view group with failed to retrieve",
			id:        testsutil.GenerateUUID(t),
			session:   validSession,
			withRoles: false,
			repoErr:   repoerr.ErrNotFound,
			err:       svcerr.ErrViewEntity,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			repoCall := repo.On("RetrieveByID", context.Background(), tc.id).Return(tc.repoResp, tc.repoErr)
			repoCall1 := repo.On("RetrieveByIDWithRoles", context.Background(), tc.id, tc.session.UserID).Return(tc.repoResp, tc.repoErr)

			// Call with the case's own session so it always agrees with the
			// RetrieveByIDWithRoles expectation registered above.
			got, err := svc.ViewGroup(context.Background(), tc.session, tc.id, tc.withRoles)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
			if err == nil {
				assert.Equal(t, tc.repoResp, got)
				if tc.withRoles {
					ok := repo.AssertCalled(t, "RetrieveByIDWithRoles", context.Background(), tc.id, tc.session.UserID)
					assert.True(t, ok, fmt.Sprintf("RetrieveByIDWithRoles was not called on %s", tc.desc))
				} else {
					ok := repo.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
					assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
				}
			}

			repoCall.Unset()
			repoCall1.Unset()
		})
	}
}
// TestUpdateGroup expects UpdateGroup to return whatever the repository's
// Update produces. It also expects a repository not-found error to be
// reported as svcerr.ErrNotFound.
func TestUpdateGroup(t *testing.T) {
	svc := newService(t)

	type testCase struct {
		desc     string
		group    groups.Group
		repoResp groups.Group
		repoErr  error
		err      error
	}

	cases := []testCase{
		{
			desc: "update group successfully",
			group: groups.Group{
				ID:   testsutil.GenerateUUID(t),
				Name: namegen.Generate(),
			},
			repoResp: validGroup,
		},
		{
			desc: "update group with repo error",
			group: groups.Group{
				ID:   testsutil.GenerateUUID(t),
				Name: namegen.Generate(),
			},
			repoErr: repoerr.ErrNotFound,
			err:     svcerr.ErrNotFound,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			updateCall := repo.On("Update", ctx, mock.Anything).Return(c.repoResp, c.repoErr)

			got, err := svc.UpdateGroup(ctx, validSession, c.group)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))

			if err == nil {
				assert.Equal(t, c.repoResp, got)
				called := repo.AssertCalled(t, "Update", ctx, mock.Anything)
				assert.True(t, called, fmt.Sprintf("Update was not called on %s", c.desc))
			}

			updateCall.Unset()
		})
	}
}
// TestUpdateGroupTags expects UpdateGroupTags to call the repository's
// UpdateTags and return its result. It also expects a repository not-found
// error to be reported as svcerr.ErrNotFound.
func TestUpdateGroupTags(t *testing.T) {
	svc := newService(t)

	// taggedGroup returns a fresh group with a new ID and the two test tags.
	taggedGroup := func() groups.Group {
		return groups.Group{
			ID:   testsutil.GenerateUUID(t),
			Tags: []string{"tag1", "tag2"},
		}
	}

	type testCase struct {
		desc      string
		updateReq groups.Group
		repoResp  groups.Group
		repoErr   error
		err       error
	}

	cases := []testCase{
		{
			desc:      "update group tags successfully",
			updateReq: taggedGroup(),
			repoResp:  taggedGroup(),
		},
		{
			desc:      "update group tags with repo error",
			updateReq: taggedGroup(),
			repoErr:   repoerr.ErrNotFound,
			err:       svcerr.ErrNotFound,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			tagsCall := repo.On("UpdateTags", ctx, mock.Anything).Return(c.repoResp, c.repoErr)

			got, err := svc.UpdateGroupTags(ctx, validSession, c.updateReq)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))

			if err == nil {
				assert.Equal(t, c.repoResp, got)
				called := repo.AssertCalled(t, "UpdateTags", ctx, mock.Anything)
				assert.True(t, called, fmt.Sprintf("UpdateTags was not called on %s", c.desc))
			}

			tagsCall.Unset()
		})
	}
}
// TestEnableGroup covers three EnableGroup outcomes:
//   - a disabled group returns the ChangeStatus result;
//   - an already enabled group yields svcerr.ErrStatusAlreadyAssigned;
//   - a RetrieveByID failure is propagated.
func TestEnableGroup(t *testing.T) {
	svc := newService(t)

	type testCase struct {
		desc         string
		id           string
		retrieveResp groups.Group
		retrieveErr  error
		changeResp   groups.Group
		changeErr    error
		err          error
	}

	cases := []testCase{
		{
			desc:         "enable group successfully",
			id:           testsutil.GenerateUUID(t),
			retrieveResp: groups.Group{Status: groups.DisabledStatus},
			changeResp:   validGroup,
		},
		{
			desc:         "enable group with enabled group",
			id:           testsutil.GenerateUUID(t),
			retrieveResp: groups.Group{Status: groups.EnabledStatus},
			err:          svcerr.ErrStatusAlreadyAssigned,
		},
		{
			desc:         "enable group with retrieve error",
			id:           testsutil.GenerateUUID(t),
			retrieveResp: groups.Group{},
			retrieveErr:  repoerr.ErrNotFound,
			err:          repoerr.ErrNotFound,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			retrieveCall := repo.On("RetrieveByID", ctx, c.id).Return(c.retrieveResp, c.retrieveErr)
			statusCall := repo.On("ChangeStatus", ctx, mock.Anything).Return(c.changeResp, c.changeErr)

			got, err := svc.EnableGroup(ctx, validSession, c.id)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))

			if err == nil {
				assert.Equal(t, c.changeResp, got)
				called := repo.AssertCalled(t, "RetrieveByID", ctx, c.id)
				assert.True(t, called, fmt.Sprintf("RetrieveByID was not called on %s", c.desc))
			}

			retrieveCall.Unset()
			statusCall.Unset()
		})
	}
}
// TestDisableGroup covers three DisableGroup outcomes:
//   - an enabled group returns the ChangeStatus result;
//   - an already disabled group yields svcerr.ErrStatusAlreadyAssigned;
//   - a RetrieveByID failure is propagated.
func TestDisableGroup(t *testing.T) {
	svc := newService(t)

	type testCase struct {
		desc         string
		id           string
		retrieveResp groups.Group
		retrieveErr  error
		changeResp   groups.Group
		changeErr    error
		err          error
	}

	cases := []testCase{
		{
			desc:         "disable group successfully",
			id:           testsutil.GenerateUUID(t),
			retrieveResp: groups.Group{Status: groups.EnabledStatus},
			changeResp:   validGroup,
		},
		{
			desc:         "disable group with disabled group",
			id:           testsutil.GenerateUUID(t),
			retrieveResp: groups.Group{Status: groups.DisabledStatus},
			err:          svcerr.ErrStatusAlreadyAssigned,
		},
		{
			desc:         "disable group with retrieve error",
			id:           testsutil.GenerateUUID(t),
			retrieveResp: groups.Group{},
			retrieveErr:  repoerr.ErrNotFound,
			err:          repoerr.ErrNotFound,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			retrieveCall := repo.On("RetrieveByID", ctx, c.id).Return(c.retrieveResp, c.retrieveErr)
			statusCall := repo.On("ChangeStatus", ctx, mock.Anything).Return(c.changeResp, c.changeErr)

			got, err := svc.DisableGroup(ctx, validSession, c.id)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))

			if err == nil {
				assert.Equal(t, c.changeResp, got)
				called := repo.AssertCalled(t, "RetrieveByID", ctx, c.id)
				assert.True(t, called, fmt.Sprintf("RetrieveByID was not called on %s", c.desc))
			}

			retrieveCall.Unset()
			statusCall.Unset()
		})
	}
}
// TestListGroups covers ListGroups for both kinds of session.
//
// RetrieveAll (keyed by page meta) is stubbed for super-admin cases.
// RetrieveUserGroups (keyed by the session's domain and user) is stubbed for
// regular sessions.
//
// The returned page must equal the expected one exactly, including the zero
// page returned on failure.
func TestListGroups(t *testing.T) {
	svc := newService(t)

	superAdmin := smqauthn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true}
	adminMeta := groups.PageMeta{Limit: 10, Offset: 0, DomainID: validID}
	userMeta := groups.PageMeta{Limit: 10, Offset: 0}

	// onePage builds a fresh page holding only validGroup, so no two fields
	// share a backing slice.
	onePage := func() groups.Page {
		return groups.Page{
			Groups:   []groups.Group{validGroup},
			PageMeta: groups.PageMeta{Total: 1},
		}
	}

	type testCase struct {
		desc                 string
		session              smqauthn.Session
		pageMeta             groups.PageMeta
		retrieveAllRes       groups.Page
		retrieveAllErr       error
		retrieveUserGroupRes groups.Page
		retrieveUserGroupErr error
		resp                 groups.Page
		err                  error
	}

	cases := []testCase{
		{
			desc:           "list groups as super admin successfully",
			session:        superAdmin,
			pageMeta:       adminMeta,
			retrieveAllRes: onePage(),
			resp:           onePage(),
			err:            nil,
		},
		{
			desc:           "list groups as super admin with failed to retrieve",
			session:        superAdmin,
			pageMeta:       adminMeta,
			retrieveAllErr: repoerr.ErrNotFound,
			resp:           groups.Page{},
			err:            repoerr.ErrNotFound,
		},
		{
			desc:                 "list groups as non admin successfully",
			session:              validSession,
			pageMeta:             userMeta,
			retrieveUserGroupRes: onePage(),
			resp:                 onePage(),
			err:                  nil,
		},
		{
			desc:                 "list groups as non admin with failed to retrieve user groups",
			session:              validSession,
			pageMeta:             userMeta,
			retrieveUserGroupErr: repoerr.ErrNotFound,
			resp:                 groups.Page{},
			err:                  svcerr.ErrViewEntity,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			allCall := repo.On("RetrieveAll", ctx, c.pageMeta).Return(c.retrieveAllRes, c.retrieveAllErr)
			userCall := repo.On("RetrieveUserGroups", ctx, c.session.DomainID, c.session.UserID, c.pageMeta).Return(c.retrieveUserGroupRes, c.retrieveUserGroupErr)

			page, err := svc.ListGroups(ctx, c.session, c.pageMeta)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))
			assert.Equal(t, c.resp, page)

			allCall.Unset()
			userCall.Unset()
		})
	}
}
// TestListUserGroups expects ListUserGroups to query RetrieveUserGroups with
// the session's domain, the requested user ID and the page meta. It also
// expects a repository failure to be reported as svcerr.ErrViewEntity with an
// empty page.
func TestListUserGroups(t *testing.T) {
	svc := newService(t)

	meta := groups.PageMeta{Limit: 10, Offset: 0}

	// onePage builds a fresh page holding only validGroup.
	onePage := func() groups.Page {
		return groups.Page{
			Groups:   []groups.Group{validGroup},
			PageMeta: groups.PageMeta{Total: 1},
		}
	}

	type testCase struct {
		desc                 string
		session              smqauthn.Session
		userID               string
		pageMeta             groups.PageMeta
		retrieveUserGroupRes groups.Page
		retrieveUserGroupErr error
		resp                 groups.Page
		err                  error
	}

	cases := []testCase{
		{
			desc:                 "list user groups successfully",
			session:              validSession,
			userID:               validID,
			pageMeta:             meta,
			retrieveUserGroupRes: onePage(),
			resp:                 onePage(),
			err:                  nil,
		},
		{
			desc:                 "list user groups with failed to retrieve",
			session:              validSession,
			userID:               validID,
			pageMeta:             meta,
			retrieveUserGroupErr: repoerr.ErrNotFound,
			resp:                 groups.Page{},
			err:                  svcerr.ErrViewEntity,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			userCall := repo.On("RetrieveUserGroups", ctx, c.session.DomainID, c.userID, c.pageMeta).Return(c.retrieveUserGroupRes, c.retrieveUserGroupErr)

			page, err := svc.ListUserGroups(ctx, c.session, c.userID, c.pageMeta)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))
			assert.Equal(t, c.resp, page)

			userCall.Unset()
		})
	}
}
// TestRetrieveGroupHierarchy expects RetrieveGroupHierarchy to query
// RetrieveHierarchy with:
//   - the session's domain and user;
//   - the requested group ID;
//   - the page meta.
//
// It also expects a repository error to be propagated.
func TestRetrieveGroupHierarchy(t *testing.T) {
	svc := newService(t)

	// Every case uses the same hierarchy page meta.
	meta := groups.HierarchyPageMeta{
		Level:     1,
		Direction: -1,
		Tree:      false,
	}

	type testCase struct {
		desc                 string
		id                   string
		pageMeta             groups.HierarchyPageMeta
		retrieveHierarchyRes groups.HierarchyPage
		retrieveHierarchyErr error
		err                  error
	}

	cases := []testCase{
		{
			desc:     "retrieve group hierarchy successfully",
			id:       parentGroup.ID,
			pageMeta: meta,
			retrieveHierarchyRes: groups.HierarchyPage{
				HierarchyPageMeta: meta,
				Groups:            []groups.Group{parentGroup},
			},
			err: nil,
		},
		{
			desc:                 "retrieve group hierarchy with failed to retrieve hierarchy",
			id:                   parentGroup.ID,
			pageMeta:             meta,
			retrieveHierarchyErr: repoerr.ErrNotFound,
			err:                  repoerr.ErrNotFound,
		},
		{
			desc:                 "retrieve group hierarchy with invalid group ID",
			id:                   testsutil.GenerateUUID(t),
			pageMeta:             meta,
			retrieveHierarchyErr: nil,
			err:                  nil,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			domainID, userID := validSession.DomainID, validSession.UserID
			hierarchyCall := repo.On("RetrieveHierarchy", ctx, domainID, userID, c.id, c.pageMeta).Return(c.retrieveHierarchyRes, c.retrieveHierarchyErr)

			_, err := svc.RetrieveGroupHierarchy(ctx, validSession, c.id, c.pageMeta)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))

			if c.err == nil {
				called := repo.AssertCalled(t, "RetrieveHierarchy", ctx, domainID, userID, c.id, c.pageMeta)
				assert.True(t, called, fmt.Sprintf("RetrieveHierarchy was not called on %s", c.desc))
			}

			hierarchyCall.Unset()
		})
	}
}
// TestAddParentGroup covers AddParentGroup.
//
// The child is looked up with RetrieveByID, and the parent-group relation
// policy is stubbed for both AddPolicies and DeletePolicies. The cases
// expect:
//   - a retrieve error to be propagated;
//   - a group that already has a parent to be rejected with svcerr.ErrConflict;
//   - a policy failure to be reported as svcerr.ErrAddPolicies;
//   - an AssignParentGroup failure to be propagated, or reported as
//     apiutil.ErrRollbackTx when deleting the policy also fails.
func TestAddParentGroup(t *testing.T) {
	svc := newService(t)

	// parentPolicy is the group-to-group relation registered with the policy
	// mock when id is given parentID as its parent.
	parentPolicy := func(id, parentID string) policysvc.Policy {
		return policysvc.Policy{
			Domain:      validID,
			SubjectType: policysvc.GroupType,
			Subject:     parentID,
			Relation:    policysvc.ParentGroupRelation,
			ObjectType:  policysvc.GroupType,
			Object:      id,
		}
	}

	type testCase struct {
		desc              string
		id                string
		parentID          string
		retrieveResp      groups.Group
		retrieveErr       error
		addPoliciesErr    error
		deletePoliciesErr error
		assignParentErr   error
		err               error
	}

	cases := []testCase{
		{
			desc:         "add parent group successfully",
			id:           validGroup.ID,
			parentID:     parentGroupID,
			retrieveResp: validGroup,
			err:          nil,
		},
		{
			desc:        "add parent group with failed to retrieve",
			id:          validGroup.ID,
			parentID:    parentGroupID,
			retrieveErr: repoerr.ErrNotFound,
			err:         repoerr.ErrNotFound,
		},
		{
			desc:         "add parent group to group with parent",
			id:           childGroupID,
			parentID:     parentGroupID,
			retrieveResp: childGroup,
			err:          svcerr.ErrConflict,
		},
		{
			desc:           "add parent group with failed to add policies",
			id:             validGroup.ID,
			parentID:       parentGroupID,
			retrieveResp:   validGroup,
			addPoliciesErr: svcerr.ErrAuthorization,
			err:            svcerr.ErrAddPolicies,
		},
		{
			desc:            "add parent group with repo error in assign parent group",
			id:              validGroup.ID,
			parentID:        parentGroupID,
			retrieveResp:    validGroup,
			assignParentErr: repoerr.ErrNotFound,
			err:             repoerr.ErrNotFound,
		},
		{
			desc:              "add parent group with repo error in assign parent group and failed to delete policies",
			id:                validGroup.ID,
			parentID:          parentGroupID,
			retrieveResp:      validGroup,
			assignParentErr:   repoerr.ErrNotFound,
			deletePoliciesErr: svcerr.ErrAuthorization,
			err:               apiutil.ErrRollbackTx,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			pol := parentPolicy(c.id, c.parentID)

			retrieveCall := repo.On("RetrieveByID", ctx, c.id).Return(c.retrieveResp, c.retrieveErr)
			addPolicyCall := policies.On("AddPolicies", ctx, []policysvc.Policy{pol}).Return(c.addPoliciesErr)
			deletePolicyCall := policies.On("DeletePolicies", ctx, []policysvc.Policy{pol}).Return(c.deletePoliciesErr)
			assignCall := repo.On("AssignParentGroup", ctx, c.parentID, []string{c.id}).Return(c.assignParentErr)

			err := svc.AddParentGroup(ctx, validSession, c.id, c.parentID)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))

			called := repo.AssertCalled(t, "RetrieveByID", ctx, c.id)
			assert.True(t, called, fmt.Sprintf("RetrieveByID was not called on %s", c.desc))

			retrieveCall.Unset()
			addPolicyCall.Unset()
			deletePolicyCall.Unset()
			assignCall.Unset()
		})
	}
}
// TestRemoveParentGroup covers RemoveParentGroup.
//
// The group is looked up with RetrieveByID. The relation policy built from
// the retrieved group's Parent is stubbed for both DeletePolicies and
// AddPolicies. The cases expect:
//   - a group without a parent to succeed;
//   - a policy failure to be reported as svcerr.ErrDeletePolicies;
//   - an UnassignParentGroup failure to be propagated, or reported as
//     apiutil.ErrRollbackTx when re-adding the policy also fails.
func TestRemoveParentGroup(t *testing.T) {
	svc := newService(t)

	// parentPolicy is the group-to-group relation between parentID and id.
	parentPolicy := func(id, parentID string) policysvc.Policy {
		return policysvc.Policy{
			Domain:      validID,
			SubjectType: policysvc.GroupType,
			Subject:     parentID,
			Relation:    policysvc.ParentGroupRelation,
			ObjectType:  policysvc.GroupType,
			Object:      id,
		}
	}

	type testCase struct {
		desc              string
		id                string
		retrieveResp      groups.Group
		retrieveErr       error
		deletePoliciesErr error
		addPoliciesErr    error
		unassignParentErr error
		err               error
	}

	cases := []testCase{
		{
			desc:         "remove parent group successfully",
			id:           childGroupID,
			retrieveResp: childGroup,
			err:          nil,
		},
		{
			desc:        "remove parent group with failed to retrieve",
			id:          childGroupID,
			retrieveErr: repoerr.ErrNotFound,
			err:         repoerr.ErrNotFound,
		},
		{
			desc:         "remove parent group with no parent",
			id:           validGroup.ID,
			retrieveResp: validGroup,
			err:          nil,
		},
		{
			desc:              "remove parent group with failed to delete policies",
			id:                childGroupID,
			retrieveResp:      childGroup,
			deletePoliciesErr: svcerr.ErrAuthorization,
			err:               svcerr.ErrDeletePolicies,
		},
		{
			desc:              "remove parent group with repo error in unassign parent group",
			id:                childGroupID,
			retrieveResp:      childGroup,
			unassignParentErr: repoerr.ErrNotFound,
			err:               repoerr.ErrNotFound,
		},
		{
			desc:              "remove parent group with repo error in unassign parent group and failed to add policies",
			id:                childGroupID,
			retrieveResp:      childGroup,
			unassignParentErr: repoerr.ErrNotFound,
			addPoliciesErr:    svcerr.ErrAuthorization,
			err:               apiutil.ErrRollbackTx,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			parentID := c.retrieveResp.Parent
			pol := parentPolicy(c.id, parentID)

			retrieveCall := repo.On("RetrieveByID", ctx, c.id).Return(c.retrieveResp, c.retrieveErr)
			deletePolicyCall := policies.On("DeletePolicies", ctx, []policysvc.Policy{pol}).Return(c.deletePoliciesErr)
			addPolicyCall := policies.On("AddPolicies", ctx, []policysvc.Policy{pol}).Return(c.addPoliciesErr)
			unassignCall := repo.On("UnassignParentGroup", ctx, parentID, []string{c.id}).Return(c.unassignParentErr)

			err := svc.RemoveParentGroup(ctx, validSession, c.id)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))

			called := repo.AssertCalled(t, "RetrieveByID", ctx, c.id)
			assert.True(t, called, fmt.Sprintf("RetrieveByID was not called on %s", c.desc))

			retrieveCall.Unset()
			deletePolicyCall.Unset()
			addPolicyCall.Unset()
			unassignCall.Unset()
		})
	}
}
// TestAddChildrenGroups covers AddChildrenGroups.
//
// The children are loaded with RetrieveByIDs. The cases expect:
//   - missing IDs to yield groups.ErrGroupIDs;
//   - a child that already has a parent to yield svcerr.ErrConflict;
//   - a policy failure to be reported as svcerr.ErrAddPolicies;
//   - an AssignParentGroup failure to be propagated, or reported as
//     apiutil.ErrRollbackTx when deleting the policy also fails.
func TestAddChildrenGroups(t *testing.T) {
	svc := newService(t)

	// noLimit is the page limit the RetrieveByIDs expectation is keyed on.
	const noLimit = 1<<63 - 1

	cases := []struct {
		desc              string
		parentID          string
		childrenIDs       []string
		retrieveResp      groups.Page
		retrieveErr       error
		addPoliciesErr    error
		deletePoliciesErr error
		assignParentErr   error
		err               error
	}{
		{
			desc:        "add children groups successfully",
			parentID:    parentGroupID,
			childrenIDs: []string{validGroup.ID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{validGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			err: nil,
		},
		{
			desc:        "add children groups with failed to retrieve",
			parentID:    parentGroupID,
			childrenIDs: []string{validGroup.ID},
			retrieveErr: repoerr.ErrNotFound,
			err:         repoerr.ErrNotFound,
		},
		{
			desc:     "add non existent child group",
			parentID: parentGroupID,
			// Use the real test handle rather than a zero-value testing.T,
			// so a generation failure is reported against this test.
			childrenIDs:  []string{testsutil.GenerateUUID(t)},
			retrieveResp: groups.Page{},
			err:          groups.ErrGroupIDs,
		},
		{
			desc:        "add child group with parent",
			parentID:    parentGroupID,
			childrenIDs: []string{childGroupID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{childGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			err: svcerr.ErrConflict,
		},
		{
			desc:        "add children groups with failed to add policies",
			parentID:    parentGroupID,
			childrenIDs: []string{validGroup.ID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{validGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			addPoliciesErr: svcerr.ErrAuthorization,
			err:            svcerr.ErrAddPolicies,
		},
		{
			desc:        "add children groups with repo error in assign children groups",
			parentID:    parentGroupID,
			childrenIDs: []string{validGroup.ID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{validGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			assignParentErr: repoerr.ErrNotFound,
			err:             repoerr.ErrNotFound,
		},
		{
			desc:        "add children groups with repo error in assign children groups and failed to delete policies",
			parentID:    parentGroupID,
			childrenIDs: []string{validGroup.ID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{validGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			assignParentErr:   repoerr.ErrNotFound,
			deletePoliciesErr: svcerr.ErrAuthorization,
			err:               apiutil.ErrRollbackTx,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			// Policy mocks are only matched for validGroup as the child, which
			// is the only child the policy-related cases use.
			pol := policysvc.Policy{
				Domain:      validID,
				SubjectType: policysvc.GroupType,
				Subject:     tc.parentID,
				Relation:    policysvc.ParentGroupRelation,
				ObjectType:  policysvc.GroupType,
				Object:      validGroup.ID,
			}
			repoCall := repo.On("RetrieveByIDs", context.Background(), groups.PageMeta{Limit: noLimit}, tc.childrenIDs).Return(tc.retrieveResp, tc.retrieveErr)
			policyCall := policies.On("AddPolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.addPoliciesErr)
			policyCall1 := policies.On("DeletePolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.deletePoliciesErr)
			repoCall1 := repo.On("AssignParentGroup", context.Background(), tc.parentID, tc.childrenIDs).Return(tc.assignParentErr)

			err := svc.AddChildrenGroups(context.Background(), validSession, tc.parentID, tc.childrenIDs)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))

			repoCall.Unset()
			policyCall.Unset()
			policyCall1.Unset()
			repoCall1.Unset()
		})
	}
}
// TestRemoveChildrenGroups covers RemoveChildrenGroups.
//
// The children are loaded with RetrieveByIDs. The cases expect:
//   - missing IDs to yield groups.ErrGroupIDs;
//   - children belonging to a different parent to yield svcerr.ErrConflict;
//   - a policy failure to be reported as svcerr.ErrDeletePolicies;
//   - an UnassignParentGroup failure to be propagated, or reported as
//     apiutil.ErrRollbackTx when re-adding the policy also fails.
func TestRemoveChildrenGroups(t *testing.T) {
	svc := newService(t)

	// noLimit is the page limit the RetrieveByIDs expectation is keyed on.
	const noLimit = 1<<63 - 1

	cases := []struct {
		desc              string
		parentID          string
		childrenIDs       []string
		retrieveResp      groups.Page
		retrieveErr       error
		deletePoliciesErr error
		addPoliciesErr    error
		unassignParentErr error
		err               error
	}{
		{
			desc:        "remove children groups successfully",
			parentID:    parentGroupID,
			childrenIDs: []string{childGroupID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{childGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			err: nil,
		},
		{
			desc:        "remove children groups with failed to retrieve",
			parentID:    parentGroupID,
			childrenIDs: []string{childGroupID},
			retrieveErr: repoerr.ErrNotFound,
			err:         repoerr.ErrNotFound,
		},
		{
			desc:     "remove non existent child group",
			parentID: parentGroupID,
			// Use the real test handle rather than a zero-value testing.T,
			// so a generation failure is reported against this test.
			childrenIDs:  []string{testsutil.GenerateUUID(t)},
			retrieveResp: groups.Page{},
			err:          groups.ErrGroupIDs,
		},
		{
			desc:        "remove children groups from different parent",
			parentID:    validGroup.ID,
			childrenIDs: []string{childGroupID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{childGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			err: svcerr.ErrConflict,
		},
		{
			desc:        "remove children groups with failed to delete policies",
			parentID:    parentGroupID,
			childrenIDs: []string{childGroupID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{childGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			deletePoliciesErr: svcerr.ErrAuthorization,
			err:               svcerr.ErrDeletePolicies,
		},
		{
			desc:        "remove children groups with repo error in unassign children groups",
			parentID:    parentGroupID,
			childrenIDs: []string{childGroupID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{childGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			unassignParentErr: repoerr.ErrNotFound,
			err:               repoerr.ErrNotFound,
		},
		{
			desc:        "remove children groups with repo error in unassign children groups and failed to add policies",
			parentID:    parentGroupID,
			childrenIDs: []string{childGroupID},
			retrieveResp: groups.Page{
				Groups: []groups.Group{childGroup},
				PageMeta: groups.PageMeta{
					Total: 1,
				},
			},
			unassignParentErr: repoerr.ErrNotFound,
			addPoliciesErr:    svcerr.ErrAuthorization,
			err:               apiutil.ErrRollbackTx,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			// Policy mocks are only matched for childGroupID as the child,
			// which is the only child the policy-related cases use.
			pol := policysvc.Policy{
				Domain:      validID,
				SubjectType: policysvc.GroupType,
				Subject:     tc.parentID,
				Relation:    policysvc.ParentGroupRelation,
				ObjectType:  policysvc.GroupType,
				Object:      childGroupID,
			}
			repoCall := repo.On("RetrieveByIDs", context.Background(), groups.PageMeta{Limit: noLimit}, tc.childrenIDs).Return(tc.retrieveResp, tc.retrieveErr)
			policyCall := policies.On("DeletePolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.deletePoliciesErr)
			policyCall1 := policies.On("AddPolicies", context.Background(), []policysvc.Policy{pol}).Return(tc.addPoliciesErr)
			repoCall1 := repo.On("UnassignParentGroup", context.Background(), tc.parentID, tc.childrenIDs).Return(tc.unassignParentErr)

			err := svc.RemoveChildrenGroups(context.Background(), validSession, tc.parentID, tc.childrenIDs)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))

			repoCall.Unset()
			policyCall.Unset()
			policyCall1.Unset()
			repoCall1.Unset()
		})
	}
}
// TestRemoveAllChildrenGroups stubs DeletePolicyFilter with a parent-group
// relation filter for the parent, and UnassignAllChildrenGroups for the same
// parent. It expects a policy failure to be reported as
// svcerr.ErrDeletePolicies and a repository failure to be propagated.
func TestRemoveAllChildrenGroups(t *testing.T) {
	svc := newService(t)

	type testCase struct {
		desc                   string
		parentID               string
		deletePolicyErr        error
		unassignAllChildrenErr error
		err                    error
	}

	cases := []testCase{
		{
			desc:     "remove all children groups successfully",
			parentID: parentGroupID,
			err:      nil,
		},
		{
			desc:            "remove all children groups with failed to delete policy",
			parentID:        parentGroupID,
			deletePolicyErr: svcerr.ErrAuthorization,
			err:             svcerr.ErrDeletePolicies,
		},
		{
			desc:                   "remove all children groups with failed to unassign all children",
			parentID:               parentGroupID,
			deletePolicyErr:        nil,
			unassignAllChildrenErr: repoerr.ErrNotFound,
			err:                    repoerr.ErrNotFound,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()

			// Object is left empty; presumably this lets the filter match
			// every child of the parent (confirm against DeletePolicyFilter).
			filter := policysvc.Policy{
				Domain:      validID,
				SubjectType: policysvc.GroupType,
				Subject:     c.parentID,
				Relation:    policysvc.ParentGroupRelation,
				ObjectType:  policysvc.GroupType,
			}
			filterCall := policies.On("DeletePolicyFilter", ctx, filter).Return(c.deletePolicyErr)
			unassignCall := repo.On("UnassignAllChildrenGroups", ctx, c.parentID).Return(c.unassignAllChildrenErr)

			err := svc.RemoveAllChildrenGroups(ctx, validSession, c.parentID)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))

			filterCall.Unset()
			unassignCall.Unset()
		})
	}
}
// TestListAllChildrenGroups expects ListChildrenGroups to query
// RetrieveChildrenGroups with:
//   - the session's domain and user;
//   - the parent ID;
//   - the start/end levels;
//   - the page meta.
//
// On success the repository page must be returned as-is. On a repository
// failure the test expects svcerr.ErrViewEntity and an empty page.
func TestListAllChildrenGroups(t *testing.T) {
	svc := newService(t)

	meta := groups.PageMeta{Limit: 10, Offset: 0}

	// childPage builds a fresh page holding only childGroup.
	childPage := func() groups.Page {
		return groups.Page{
			Groups:   []groups.Group{childGroup},
			PageMeta: groups.PageMeta{Total: 1},
		}
	}

	type testCase struct {
		desc        string
		session     smqauthn.Session
		pageMeta    groups.PageMeta
		parentID    string
		startLevel  int64
		endLevel    int64
		retrieveRes groups.Page
		retrieveErr error
		resp        groups.Page
		err         error
	}

	cases := []testCase{
		{
			desc:        "list all children groups successfully",
			session:     validSession,
			parentID:    parentGroupID,
			pageMeta:    meta,
			startLevel:  0,
			endLevel:    -1,
			retrieveRes: childPage(),
			resp:        childPage(),
			err:         nil,
		},
		{
			desc:        "list all children groups with failed to retrieve",
			session:     validSession,
			parentID:    parentGroupID,
			pageMeta:    meta,
			startLevel:  0,
			endLevel:    -1,
			retrieveErr: repoerr.ErrNotFound,
			resp:        groups.Page{},
			err:         svcerr.ErrViewEntity,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			childrenCall := repo.On("RetrieveChildrenGroups", ctx, c.session.DomainID, c.session.UserID, c.parentID, c.startLevel, c.endLevel, c.pageMeta).Return(c.retrieveRes, c.retrieveErr)

			got, err := svc.ListChildrenGroups(ctx, c.session, c.parentID, c.startLevel, c.endLevel, c.pageMeta)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))
			assert.Equal(t, c.resp, got)

			childrenCall.Unset()
		})
	}
}
// TestDeleteGroup covers DeleteGroup with every collaborator stubbed:
//   - ChangeStatus to groups.DeletedStatus;
//   - detaching the group as a parent from channels and clients over gRPC;
//   - the role/member lookup;
//   - policy cleanup;
//   - the repository Delete.
//
// Each failing case injects one error and expects it to be reported.
func TestDeleteGroup(t *testing.T) {
	svc := newService(t)

	type testCase struct {
		desc              string
		id                string
		changeStatusRes   groups.Group
		changeStatusErr   error
		deletePoliciesErr error
		deleteErr         error
		unsetFromChannels error
		unsetFromClients  error
		err               error
	}

	cases := []testCase{
		{
			desc: "delete group successfully",
			id:   validGroup.ID,
			err:  nil,
		},
		{
			desc:            "delete group with parent successfully",
			id:              childGroupID,
			changeStatusRes: childGroup,
			err:             nil,
		},
		{
			desc:              "delete group with failed to remove parent group from channels",
			id:                validGroup.ID,
			unsetFromChannels: svcerr.ErrRemoveEntity,
			err:               svcerr.ErrRemoveEntity,
		},
		{
			desc:              "delete group with failed to remove parent group from clients",
			id:                validGroup.ID,
			unsetFromChannels: nil,
			unsetFromClients:  svcerr.ErrRemoveEntity,
			err:               svcerr.ErrRemoveEntity,
		},
		{
			desc:            "delete group with failed to change status",
			id:              validGroup.ID,
			changeStatusErr: repoerr.ErrNotFound,
			err:             repoerr.ErrNotFound,
		},
		{
			desc:            "delete group with failed to delete",
			id:              validGroup.ID,
			changeStatusRes: validGroup,
			deleteErr:       repoerr.ErrNotFound,
			err:             repoerr.ErrNotFound,
		},
		{
			desc:              "delete group with failed to delete policies",
			id:                validGroup.ID,
			changeStatusRes:   validGroup,
			deleteErr:         nil,
			deletePoliciesErr: svcerr.ErrAuthorization,
			err:               svcerr.ErrDeletePolicies,
		},
	}

	for _, c := range cases {
		t.Run(c.desc, func(t *testing.T) {
			ctx := context.Background()
			deleted := groups.Group{ID: c.id, Status: groups.DeletedStatus}
			channelsReq := &grpcChannelsV1.UnsetParentGroupFromChannelsReq{ParentGroupId: c.id}
			clientsReq := &grpcClientsV1.UnsetParentGroupFromClientReq{ParentGroupId: c.id}

			statusCall := repo.On("ChangeStatus", ctx, deleted).Return(c.changeStatusRes, c.changeStatusErr)
			deleteCall := repo.On("Delete", ctx, c.id).Return(c.deleteErr)
			channelsCall := channels.On("UnsetParentGroupFromChannels", ctx, channelsReq).Return(&grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, c.unsetFromChannels)
			clientsCall := clients.On("UnsetParentGroupFromClient", ctx, clientsReq).Return(&grpcClientsV1.UnsetParentGroupFromClientRes{}, c.unsetFromClients)
			rolesCall := repo.On("RetrieveEntitiesRolesActionsMembers", ctx, []string{c.id}).Return([]roles.EntityActionRole{}, []roles.EntityMemberRole{}, nil)
			filterCall := policies.On("DeletePolicyFilter", ctx, mock.Anything).Return(c.deletePoliciesErr)
			deletePoliciesCall := policies.On("DeletePolicies", ctx, mock.Anything).Return(nil)

			err := svc.DeleteGroup(ctx, validSession, c.id)
			assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v to contain %v", err, c.err))

			for _, call := range []*mock.Call{filterCall, statusCall, deleteCall, channelsCall, clientsCall, rolesCall, deletePoliciesCall} {
				call.Unset()
			}
		})
	}
}