Files
Dušan Borovčanin 61d0427898 NOISSUE - Rename to Magistrala (#3427)
Signed-off-by: dusan <borovcanindusan1@gmail.com>
2026-04-06 15:23:42 +02:00

1286 lines
43 KiB
Go

// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package clients_test
import (
"context"
"fmt"
"testing"
grpcChannelsV1 "github.com/absmach/magistrala/api/grpc/channels/v1"
grpcCommonV1 "github.com/absmach/magistrala/api/grpc/common/v1"
apiutil "github.com/absmach/magistrala/api/http/util"
chmocks "github.com/absmach/magistrala/channels/mocks"
"github.com/absmach/magistrala/clients"
climocks "github.com/absmach/magistrala/clients/mocks"
gpmocks "github.com/absmach/magistrala/groups/mocks"
"github.com/absmach/magistrala/internal/testsutil"
smqauthn "github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
policysvc "github.com/absmach/magistrala/pkg/policies"
policymocks "github.com/absmach/magistrala/pkg/policies/mocks"
"github.com/absmach/magistrala/pkg/roles"
"github.com/absmach/magistrala/pkg/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
// Shared fixtures used by the test tables in this file.
var (
// secret is the credential secret given to the fixture clients below.
secret = "strongsecret"
// validMetadata is used as both Metadata and PrivateMetadata of the fixtures.
validMetadata = clients.Metadata{"role": "client"}
// ID is the fixed identifier shared by client and clientWithRoles.
ID = "6e5e10b3-d4df-4758-b426-4929d55ad740"
// client is an enabled client with every descriptive field populated and no roles.
client = clients.Client{
ID: ID,
Name: "clientname",
Tags: []string{"tag1", "tag2"},
Credentials: clients.Credentials{Identity: "clientidentity", Secret: secret},
PrivateMetadata: validMetadata,
Metadata: validMetadata,
Status: clients.EnabledStatus,
}
// clientWithRoles carries the same field values as client plus one role entry.
clientWithRoles = clients.Client{
ID: ID,
Name: "clientname",
Tags: []string{"tag1", "tag2"},
Credentials: clients.Credentials{Identity: "clientidentity", Secret: secret},
PrivateMetadata: validMetadata,
Metadata: validMetadata,
Status: clients.EnabledStatus,
Roles: []roles.MemberRoleActions{
{
RoleID: "test_role_id",
RoleName: "test_role_name",
},
},
}
// validToken is the token value stored in the create-client test table.
validToken = "token"
// validID is a fixed UUID used as user, domain and parent group identifier in sessions and fixtures.
validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22"
// wrongID is a UUID generated at package initialisation and used by the
// "non-existing client" cases.
// NOTE(review): a zero-value testing.T is passed because no real *testing.T
// exists at this point; confirm GenerateUUID cannot report a failure on it.
wrongID = testsutil.GenerateUUID(&testing.T{})
)
// Package-level handles to the mocks backing the service under test.
// newService replaces all of them on every call, so each test that calls
// newService programs its own fresh set.
var (
// pService mocks the policy service.
pService *policymocks.Service
// cache mocks the clients cache.
cache *climocks.Cache
// repo mocks the clients repository.
repo *climocks.Repository
// chgRPCClient mocks the channels gRPC client.
chgRPCClient *chmocks.ChannelsServiceClient
// gpgRPCClient mocks the groups gRPC client.
gpgRPCClient *gpmocks.GroupsServiceClient
)
// newService builds a clients.Service wired to fresh mocks and stores those
// mocks in the package-level variables (pService, cache, repo, chgRPCClient,
// gpgRPCClient) so the calling test can program them.
//
// The service is created with an empty action list and a single built-in
// admin role. If construction fails the function panics with the underlying
// error instead of silently returning an unusable service, so a broken setup
// is reported at its source rather than as an unrelated failure in a test.
func newService() clients.Service {
	pService = new(policymocks.Service)
	cache = new(climocks.Cache)
	idProvider := uuid.NewMock()
	sidProvider := uuid.NewMock()
	repo = new(climocks.Repository)
	chgRPCClient = new(chmocks.ChannelsServiceClient)
	gpgRPCClient = new(gpmocks.GroupsServiceClient)

	availableActions := []roles.Action{}
	builtInRoles := map[roles.BuiltInRoleName][]roles.Action{
		clients.BuiltInRoleAdmin: availableActions,
	}

	svc, err := clients.NewService(repo, pService, cache, chgRPCClient, gpgRPCClient, idProvider, sidProvider, availableActions, builtInRoles)
	if err != nil {
		// Test-only helper without access to *testing.T: fail loudly.
		panic(fmt.Sprintf("failed to create clients service: %s", err))
	}

	return svc
}
// TestCreateClients runs Service.CreateClients against mocked repository and
// policy dependencies. Each case programs the errors returned by the Save,
// AddPolicies, DeletePolicies, AddRoles and Delete mocks and checks the
// resulting error; on success the returned client is compared with the input
// after copying over the fields taken from the service response (ID,
// timestamps, secret, domain and UpdatedBy).
func TestCreateClients(t *testing.T) {
	svc := newService()

	cases := []struct {
		desc            string
		client          clients.Client
		token           string
		addPolicyErr    error
		deletePolicyErr error
		saveErr         error
		addRoleErr      error
		deleteErr       error
		err             error
	}{
		{
			desc:   "create a new client successfully",
			client: client,
			token:  validToken,
			err:    nil,
		},
		{
			desc:    "create an existing client",
			client:  client,
			token:   validToken,
			saveErr: repoerr.ErrConflict,
			err:     repoerr.ErrConflict,
		},
		{
			desc: "create a new client without secret",
			client: clients.Client{
				Name: "clientWithoutSecret",
				Credentials: clients.Credentials{
					Identity: "newclientwithoutsecret@example.com",
				},
				Status: clients.EnabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			// Only the secret is provided: this case used to be a copy of the
			// "without secret" case and still carried an identity.
			desc: "create a new client without identity",
			client: clients.Client{
				Name: "clientWithoutIdentity",
				Credentials: clients.Credentials{
					Secret: secret,
				},
				Status: clients.EnabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			desc: "create a new enabled client with name",
			client: clients.Client{
				Name: "clientWithName",
				Credentials: clients.Credentials{
					Identity: "newclientwithname@example.com",
					Secret:   secret,
				},
				Status: clients.EnabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			// Status is set explicitly so the case does not depend on the
			// zero value of the status type.
			desc: "create a new disabled client with name",
			client: clients.Client{
				Name: "clientWithName",
				Credentials: clients.Credentials{
					Identity: "newclientwithname@example.com",
					Secret:   secret,
				},
				Status: clients.DisabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			desc: "create a new enabled client with tags",
			client: clients.Client{
				Tags: []string{"tag1", "tag2"},
				Credentials: clients.Credentials{
					Identity: "newclientwithtags@example.com",
					Secret:   secret,
				},
				Status: clients.EnabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			desc: "create a new disabled client with tags",
			client: clients.Client{
				Tags: []string{"tag1", "tag2"},
				Credentials: clients.Credentials{
					Identity: "newclientwithtags@example.com",
					Secret:   secret,
				},
				Status: clients.DisabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			desc: "create a new enabled client with private metadata",
			client: clients.Client{
				Credentials: clients.Credentials{
					Identity: "newclientwithmetadata@example.com",
					Secret:   secret,
				},
				PrivateMetadata: validMetadata,
				Status:          clients.EnabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			desc: "create a new enabled client with metadata",
			client: clients.Client{
				Credentials: clients.Credentials{
					Identity: "newclientwithmetadata@example.com",
					Secret:   secret,
				},
				Metadata: validMetadata,
				Status:   clients.EnabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			// Status is set explicitly to match the description.
			desc: "create a new disabled client with private metadata",
			client: clients.Client{
				Credentials: clients.Credentials{
					Identity: "newclientwithmetadata@example.com",
					Secret:   secret,
				},
				PrivateMetadata: validMetadata,
				Status:          clients.DisabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			// Status is set explicitly to match the description.
			desc: "create a new disabled client",
			client: clients.Client{
				Credentials: clients.Credentials{
					Identity: "newclientwithvalidstatus@example.com",
					Secret:   secret,
				},
				Status: clients.DisabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			desc: "create a new client with valid disabled status",
			client: clients.Client{
				Credentials: clients.Credentials{
					Identity: "newclientwithvalidstatus@example.com",
					Secret:   secret,
				},
				Status: clients.DisabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			desc: "create a new client with all fields",
			client: clients.Client{
				Name: "newclientwithallfields",
				Tags: []string{"tag1", "tag2"},
				Credentials: clients.Credentials{
					Identity: "newclientwithallfields@example.com",
					Secret:   secret,
				},
				PrivateMetadata: clients.Metadata{
					"name": "newclientwithallfields",
				},
				Metadata: clients.Metadata{
					"name": "newclientwithallfields",
				},
				Status: clients.EnabledStatus,
			},
			token: validToken,
			err:   nil,
		},
		{
			desc: "create a new client with invalid status",
			client: clients.Client{
				Credentials: clients.Credentials{
					Identity: "newclientwithinvalidstatus@example.com",
					Secret:   secret,
				},
				Status: clients.AllStatus,
			},
			token: validToken,
			err:   svcerr.ErrInvalidStatus,
		},
		{
			desc: "create a new client with failed add policies response",
			client: clients.Client{
				Credentials: clients.Credentials{
					Identity: "newclientwithfailedpolicy@example.com",
					Secret:   secret,
				},
				Status: clients.EnabledStatus,
			},
			token:        validToken,
			addPolicyErr: svcerr.ErrInvalidPolicy,
			err:          svcerr.ErrInvalidPolicy,
		},
		{
			desc: "create a new client with failed delete policies response",
			client: clients.Client{
				Credentials: clients.Credentials{
					Identity: "newclientwithfailedpolicy@example.com",
					Secret:   secret,
				},
				Status: clients.EnabledStatus,
			},
			token:           validToken,
			saveErr:         repoerr.ErrConflict,
			deletePolicyErr: svcerr.ErrInvalidPolicy,
			err:             repoerr.ErrConflict,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			repoCall := repo.On("Save", context.Background(), mock.Anything).Return([]clients.Client{tc.client}, tc.saveErr)
			policyCall := pService.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPolicyErr)
			policyCall1 := pService.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePolicyErr)
			repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr)
			repoCall2 := repo.On("Delete", context.Background(), mock.Anything).Return(tc.deleteErr)

			created, _, err := svc.CreateClients(context.Background(), smqauthn.Session{}, tc.client)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
			if err == nil {
				// Copy the fields taken from the service response so that
				// only the caller-provided fields are compared.
				tc.client.ID = created[0].ID
				tc.client.CreatedAt = created[0].CreatedAt
				tc.client.UpdatedAt = created[0].UpdatedAt
				tc.client.Credentials.Secret = created[0].Credentials.Secret
				tc.client.Domain = created[0].Domain
				tc.client.UpdatedBy = created[0].UpdatedBy
				assert.Equal(t, tc.client, created[0], fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.client, created[0]))
			}

			repoCall.Unset()
			policyCall.Unset()
			policyCall1.Unset()
			repoCall1.Unset()
			repoCall2.Unset()
		})
	}
}
// TestViewClient checks Service.View with and without role expansion.
//
// The repository mocks return tc.retrieveErr, and tc.err is the error the
// service is expected to surface. Previously the mocks returned tc.err
// directly, which left the retrieveErr column unused and tied the mocked
// failure to the expectation.
func TestViewClient(t *testing.T) {
	svc := newService()

	cases := []struct {
		desc        string
		clientID    string
		withRoles   bool
		response    clients.Client
		retrieveErr error
		err         error
	}{
		{
			desc:      "view client successfully",
			response:  client,
			withRoles: false,
			clientID:  client.ID,
			err:       nil,
		},
		{
			desc:      "view client successfully with roles",
			response:  clientWithRoles,
			withRoles: true,
			clientID:  clientWithRoles.ID,
			err:       nil,
		},
		{
			desc:        "view client with an invalid token",
			response:    clients.Client{},
			withRoles:   false,
			clientID:    "",
			retrieveErr: svcerr.ErrAuthorization,
			err:         svcerr.ErrAuthorization,
		},
		{
			desc:        "view client with valid token and invalid client id",
			response:    clients.Client{},
			withRoles:   false,
			clientID:    wrongID,
			retrieveErr: svcerr.ErrNotFound,
			err:         svcerr.ErrNotFound,
		},
		{
			desc:        "view client with an invalid token and invalid client id",
			response:    clients.Client{},
			withRoles:   false,
			clientID:    wrongID,
			retrieveErr: svcerr.ErrAuthorization,
			err:         svcerr.ErrAuthorization,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			repoCall := repo.On("RetrieveByID", context.Background(), tc.clientID).Return(tc.response, tc.retrieveErr)
			repoCall1 := repo.On("RetrieveByIDWithRoles", context.Background(), tc.clientID, mock.Anything).Return(tc.response, tc.retrieveErr)

			rClient, err := svc.View(context.Background(), smqauthn.Session{}, tc.clientID, tc.withRoles)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))

			// Make sure the repository method matching withRoles was used.
			switch tc.withRoles {
			case true:
				assert.NotEmpty(t, rClient.Roles)
				ok := repo.AssertCalled(t, "RetrieveByIDWithRoles", context.Background(), tc.clientID, mock.Anything)
				assert.True(t, ok, fmt.Sprintf("RetrieveByIDWithRoles was not called on %s", tc.desc))
			default:
				ok := repo.AssertCalled(t, "RetrieveByID", context.Background(), tc.clientID)
				assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
			}
			assert.Equal(t, tc.response, rClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rClient))

			repoCall.Unset()
			repoCall1.Unset()
		})
	}
}
// TestListClients checks Service.ListClients for regular and super-admin
// sessions.
//
// Non-admin cases mock both RetrieveAll and RetrieveUserClients on the
// repository; admin cases mock RetrieveAll only. Both mocks return the same
// page and error for a given case.
//
// The case type holds only the values the test reads. Earlier revisions also
// carried policy-listing and permission columns (userKind, id, size,
// listObjects*, listPermissions*) that nothing consumed.
func TestListClients(t *testing.T) {
	svc := newService()

	adminID := testsutil.GenerateUUID(t)
	domainID := testsutil.GenerateUUID(t)
	nonAdminID := testsutil.GenerateUUID(t)

	type listCase struct {
		desc                string
		session             smqauthn.Session
		page                clients.Page
		retrieveAllResponse clients.ClientsPage
		response            clients.ClientsPage
		retrieveAllErr      error
		err                 error
	}

	// twoClientsPage returns a fresh page on every call so that the mocked
	// repository response and the expected result never share a slice.
	twoClientsPage := func() clients.ClientsPage {
		return clients.ClientsPage{
			Page: clients.Page{
				Total:  2,
				Offset: 0,
				Limit:  100,
			},
			Clients: []clients.Client{client, client},
		}
	}

	nonAdminSession := smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false}
	adminSession := smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true}

	nonAdminCases := []listCase{
		{
			desc:    "list all clients successfully as non admin",
			session: nonAdminSession,
			page: clients.Page{
				Offset: 0,
				Limit:  100,
			},
			retrieveAllResponse: twoClientsPage(),
			response:            twoClientsPage(),
			err:                 nil,
		},
		{
			desc:    "list all clients as non admin with failed to retrieve all",
			session: nonAdminSession,
			page: clients.Page{
				Offset: 0,
				Limit:  100,
			},
			retrieveAllResponse: clients.ClientsPage{},
			response:            clients.ClientsPage{},
			retrieveAllErr:      repoerr.ErrNotFound,
			err:                 svcerr.ErrNotFound,
		},
		{
			desc:    "list all clients as non admin with failed super admin",
			session: nonAdminSession,
			page: clients.Page{
				Offset: 0,
				Limit:  100,
			},
			response: clients.ClientsPage{},
			err:      nil,
		},
		{
			// NOTE(review): no policy mock is registered by this test, so the
			// failure here comes from the repository with an empty session;
			// confirm whether the description should be updated.
			desc: "list all clients as non admin with failed to list objects",
			page: clients.Page{
				Offset: 0,
				Limit:  100,
			},
			retrieveAllErr: repoerr.ErrNotFound,
			response:       clients.ClientsPage{},
			err:            svcerr.ErrNotFound,
		},
	}

	for _, tc := range nonAdminCases {
		t.Run(tc.desc, func(t *testing.T) {
			retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
			retrieveUserClientsCall := repo.On("RetrieveUserClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)

			page, err := svc.ListClients(context.Background(), tc.session, tc.page)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
			assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))

			retrieveAllCall.Unset()
			retrieveUserClientsCall.Unset()
		})
	}

	adminCases := []listCase{
		{
			desc:    "list all clients as admin successfully",
			session: adminSession,
			page: clients.Page{
				Offset: 0,
				Limit:  100,
				Domain: domainID,
			},
			retrieveAllResponse: twoClientsPage(),
			response:            twoClientsPage(),
			err:                 nil,
		},
		{
			desc:    "list all clients as admin with failed to retrieve all",
			session: adminSession,
			page: clients.Page{
				Offset: 0,
				Limit:  100,
				Domain: domainID,
			},
			retrieveAllResponse: clients.ClientsPage{},
			retrieveAllErr:      repoerr.ErrNotFound,
			err:                 svcerr.ErrNotFound,
		},
		{
			desc:    "list all clients as admin with failed to list clients",
			session: adminSession,
			page: clients.Page{
				Offset: 0,
				Limit:  100,
				Domain: domainID,
			},
			retrieveAllResponse: clients.ClientsPage{},
			retrieveAllErr:      repoerr.ErrNotFound,
			err:                 svcerr.ErrNotFound,
		},
	}

	for _, tc := range adminCases {
		t.Run(tc.desc, func(t *testing.T) {
			retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)

			page, err := svc.ListClients(context.Background(), tc.session, tc.page)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
			assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))

			retrieveAllCall.Unset()
		})
	}
}
// TestUpdateClient checks Service.Update with a mocked repository Update
// call: a rename, a metadata change and a repository failure that must be
// reported as svcerr.ErrUpdateEntity.
func TestUpdateClient(t *testing.T) {
	svc := newService()
	ctx := context.Background()
	userSession := smqauthn.Session{UserID: validID}

	// Two independent copies of the shared fixture, each changed differently.
	renamed, remeta := client, client
	renamed.Name = "Updated client"
	remeta.PrivateMetadata = clients.Metadata{"role": "test"}
	remeta.Metadata = clients.Metadata{"role": "test"}

	type updateCase struct {
		desc    string
		input   clients.Client
		session smqauthn.Session
		stored  clients.Client // value returned by the mocked repository
		repoErr error          // error returned by the mocked repository
		wantErr error
	}

	table := []updateCase{
		{
			desc:    "update client name successfully",
			input:   renamed,
			session: userSession,
			stored:  renamed,
		},
		{
			desc:    "update client metadata with valid token",
			input:   remeta,
			session: userSession,
			stored:  remeta,
		},
		{
			desc:    "update client with failed to update repo",
			input:   renamed,
			session: userSession,
			repoErr: repoerr.ErrMalformedEntity,
			wantErr: svcerr.ErrUpdateEntity,
		},
	}

	for _, tt := range table {
		t.Run(tt.desc, func(t *testing.T) {
			updateCall := repo.On("Update", ctx, mock.Anything).Return(tt.stored, tt.repoErr)

			got, err := svc.Update(ctx, tt.session, tt.input)

			errMsg := fmt.Sprintf("%s: expected %s got %s\n", tt.desc, tt.wantErr, err)
			assert.True(t, errors.Contains(err, tt.wantErr), errMsg)
			valMsg := fmt.Sprintf("%s: expected %v got %v\n", tt.desc, tt.stored, got)
			assert.Equal(t, tt.stored, got, valMsg)

			updateCall.Unset()
		})
	}
}
// TestUpdateTags checks Service.UpdateTags with a mocked repository
// UpdateTags call, covering a successful update and a repository failure
// that must be reported as svcerr.ErrUpdateEntity.
//
// The tags are changed on a local copy of the shared client fixture.
// Assigning to client.Tags directly would alter the package-level value that
// the other tests in this file rely on.
func TestUpdateTags(t *testing.T) {
	svc := newService()

	taggedClient := client
	taggedClient.Tags = []string{"updated"}

	cases := []struct {
		desc           string
		client         clients.Client
		session        smqauthn.Session
		updateResponse clients.Client
		updateErr      error
		err            error
	}{
		{
			desc:           "update client tags successfully",
			client:         taggedClient,
			session:        smqauthn.Session{UserID: validID},
			updateResponse: taggedClient,
			err:            nil,
		},
		{
			desc:           "update client tags with failed to update repo",
			client:         taggedClient,
			updateResponse: clients.Client{},
			session:        smqauthn.Session{UserID: validID},
			updateErr:      repoerr.ErrMalformedEntity,
			err:            svcerr.ErrUpdateEntity,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			repoCall1 := repo.On("UpdateTags", context.Background(), mock.Anything).Return(tc.updateResponse, tc.updateErr)

			updatedClient, err := svc.UpdateTags(context.Background(), tc.session, tc.client)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
			assert.Equal(t, tc.updateResponse, updatedClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateResponse, updatedClient))

			repoCall1.Unset()
		})
	}
}
// TestUpdateSecret checks Service.UpdateSecret. The repository UpdateSecret
// call is always mocked; the cache Remove call is mocked only when the
// repository update is set to succeed.
func TestUpdateSecret(t *testing.T) {
	svc := newService()
	ctx := context.Background()
	userSession := smqauthn.Session{UserID: validID}

	// rotated builds the client the mocked repository returns after a
	// successful secret update.
	rotated := func() clients.Client {
		return clients.Client{
			ID: client.ID,
			Credentials: clients.Credentials{
				Identity: client.Credentials.Identity,
				Secret:   "newSecret",
			},
		}
	}

	type secretCase struct {
		desc      string
		target    clients.Client
		newSecret string
		stored    clients.Client // value returned by the mocked repository
		session   smqauthn.Session
		repoErr   error // error returned by the mocked repository
		cacheErr  error // error returned by the mocked cache
		wantErr   error
	}

	table := []secretCase{
		{
			desc:      "update client secret successfully",
			target:    client,
			newSecret: "newSecret",
			session:   userSession,
			stored:    rotated(),
		},
		{
			desc:      "update client secret with failed to update repo",
			target:    client,
			newSecret: "newSecret",
			session:   userSession,
			repoErr:   repoerr.ErrMalformedEntity,
			wantErr:   svcerr.ErrUpdateEntity,
		},
		{
			desc:      "update client secret with failed to remove cache",
			target:    client,
			newSecret: "newSecret",
			session:   userSession,
			stored:    rotated(),
			cacheErr:  repoerr.ErrRemoveEntity,
			wantErr:   svcerr.ErrRemoveEntity,
		},
	}

	for _, tt := range table {
		t.Run(tt.desc, func(t *testing.T) {
			registered := []*mock.Call{
				repo.On("UpdateSecret", ctx, mock.Anything).Return(tt.stored, tt.repoErr),
			}
			if tt.repoErr == nil {
				registered = append(registered, cache.On("Remove", ctx, tt.stored.ID).Return(tt.cacheErr))
			}

			got, err := svc.UpdateSecret(ctx, tt.session, tt.target.ID, tt.newSecret)

			assert.True(t, errors.Contains(err, tt.wantErr), fmt.Sprintf("%s: expected %s got %s\n", tt.desc, tt.wantErr, err))
			assert.Equal(t, tt.stored, got, fmt.Sprintf("%s: expected %v got %v\n", tt.desc, tt.stored, got))

			for _, call := range registered {
				call.Unset()
			}
		})
	}
}
// TestEnable checks the error returned by Service.Enable while the
// repository RetrieveByID and ChangeStatus calls are mocked. Only the error
// is asserted; the returned client is ignored.
func TestEnable(t *testing.T) {
	svc := newService()
	ctx := context.Background()
	userSession := smqauthn.Session{UserID: validID}

	alreadyEnabled := clients.Client{
		ID:          ID,
		Credentials: clients.Credentials{Identity: "client1@example.com", Secret: "password"},
		Status:      clients.EnabledStatus,
	}
	disabled := clients.Client{
		ID:          ID,
		Credentials: clients.Credentials{Identity: "client3@example.com", Secret: "password"},
		Status:      clients.DisabledStatus,
	}
	// nowEnabled is the disabled fixture after its status has been flipped.
	nowEnabled := disabled
	nowEnabled.Status = clients.EnabledStatus

	type enableCase struct {
		desc        string
		id          string
		session     smqauthn.Session
		client      clients.Client
		changed     clients.Client // returned by the mocked ChangeStatus
		retrieved   clients.Client // returned by the mocked RetrieveByID
		changeErr   error
		retrieveErr error
		wantErr     error
	}

	table := []enableCase{
		{
			desc:      "enable disabled client",
			id:        disabled.ID,
			session:   userSession,
			client:    disabled,
			changed:   nowEnabled,
			retrieved: disabled,
		},
		{
			desc:      "enable disabled client with failed to update repo",
			id:        disabled.ID,
			session:   userSession,
			client:    disabled,
			retrieved: disabled,
			changeErr: repoerr.ErrMalformedEntity,
			wantErr:   svcerr.ErrUpdateEntity,
		},
		{
			desc:      "enable enabled client",
			id:        alreadyEnabled.ID,
			session:   userSession,
			client:    alreadyEnabled,
			changed:   alreadyEnabled,
			retrieved: alreadyEnabled,
			changeErr: svcerr.ErrStatusAlreadyAssigned,
			wantErr:   svcerr.ErrStatusAlreadyAssigned,
		},
		{
			desc:        "enable non-existing client",
			id:          wrongID,
			session:     userSession,
			retrieveErr: repoerr.ErrNotFound,
			wantErr:     repoerr.ErrNotFound,
		},
	}

	for _, tt := range table {
		t.Run(tt.desc, func(t *testing.T) {
			registered := []*mock.Call{
				repo.On("RetrieveByID", ctx, mock.Anything).Return(tt.retrieved, tt.retrieveErr),
				repo.On("ChangeStatus", ctx, mock.Anything).Return(tt.changed, tt.changeErr),
			}

			_, err := svc.Enable(ctx, tt.session, tt.id)
			assert.True(t, errors.Contains(err, tt.wantErr), fmt.Sprintf("%s: expected %s got %s\n", tt.desc, tt.wantErr, err))

			for _, call := range registered {
				call.Unset()
			}
		})
	}
}
// TestDisable checks the error returned by Service.Disable while the
// repository RetrieveByID and ChangeStatus calls and the cache Remove call
// are mocked. Only the error is asserted; the returned client is ignored.
func TestDisable(t *testing.T) {
	svc := newService()
	ctx := context.Background()
	userSession := smqauthn.Session{UserID: validID}

	enabled := clients.Client{
		ID:          ID,
		Credentials: clients.Credentials{Identity: "client1@example.com", Secret: "password"},
		Status:      clients.EnabledStatus,
	}
	alreadyDisabled := clients.Client{
		ID:          ID,
		Credentials: clients.Credentials{Identity: "client3@example.com", Secret: "password"},
		Status:      clients.DisabledStatus,
	}
	// nowDisabled is the enabled fixture after its status has been flipped.
	nowDisabled := enabled
	nowDisabled.Status = clients.DisabledStatus

	type disableCase struct {
		desc        string
		id          string
		session     smqauthn.Session
		client      clients.Client
		changed     clients.Client // returned by the mocked ChangeStatus
		retrieved   clients.Client // returned by the mocked RetrieveByID
		changeErr   error
		retrieveErr error
		cacheErr    error // returned by the mocked cache Remove
		wantErr     error
	}

	table := []disableCase{
		{
			desc:      "disable enabled client",
			id:        enabled.ID,
			session:   userSession,
			client:    enabled,
			changed:   nowDisabled,
			retrieved: enabled,
		},
		{
			desc:      "disable client with failed to update repo",
			id:        enabled.ID,
			session:   userSession,
			client:    enabled,
			retrieved: enabled,
			changeErr: repoerr.ErrMalformedEntity,
			wantErr:   svcerr.ErrUpdateEntity,
		},
		{
			desc:      "disable disabled client",
			id:        alreadyDisabled.ID,
			session:   userSession,
			client:    alreadyDisabled,
			retrieved: alreadyDisabled,
			changeErr: svcerr.ErrStatusAlreadyAssigned,
			wantErr:   svcerr.ErrStatusAlreadyAssigned,
		},
		{
			desc:        "disable non-existing client",
			id:          wrongID,
			session:     userSession,
			retrieveErr: repoerr.ErrNotFound,
			wantErr:     repoerr.ErrNotFound,
		},
		{
			desc:      "disable client with failed to remove from cache",
			id:        enabled.ID,
			session:   userSession,
			client:    alreadyDisabled,
			changed:   nowDisabled,
			retrieved: enabled,
			cacheErr:  svcerr.ErrRemoveEntity,
			wantErr:   svcerr.ErrRemoveEntity,
		},
	}

	for _, tt := range table {
		t.Run(tt.desc, func(t *testing.T) {
			registered := []*mock.Call{
				repo.On("RetrieveByID", ctx, mock.Anything).Return(tt.retrieved, tt.retrieveErr),
				repo.On("ChangeStatus", ctx, mock.Anything).Return(tt.changed, tt.changeErr),
				cache.On("Remove", mock.Anything, mock.Anything).Return(tt.cacheErr),
			}

			_, err := svc.Disable(ctx, tt.session, tt.id)
			assert.True(t, errors.Contains(err, tt.wantErr), fmt.Sprintf("%s: expected %s got %s\n", tt.desc, tt.wantErr, err))

			for _, call := range registered {
				call.Unset()
			}
		})
	}
}
// TestDelete checks the error returned by Service.Delete. Every collaborator
// is mocked for every case: the repository connection check, the channels
// gRPC RemoveClientConnections call, the cache Remove call, the repository
// ChangeStatus (to clients.DeletedStatus), role lookup and Delete calls, and
// the policy service DeletePolicies / DeletePolicyFilter calls.
//
// The client under deletion is held in a local named target so that it does
// not shadow the package-level client fixture.
func TestDelete(t *testing.T) {
	svc := newService()

	target := clients.Client{
		ID: testsutil.GenerateUUID(t),
	}

	cases := []struct {
		desc                 string
		clientID             string
		checkConnectionsRes  bool
		checkConnectionsErr  error
		removeConnectionsErr error
		changeStatusErr      error
		deletePoliciesErr    error
		removeErr            error
		deleteErr            error
		err                  error
	}{
		{
			desc:     "Delete client without connections successfully",
			clientID: target.ID,
			err:      nil,
		},
		{
			desc:                "Delete client with connections",
			clientID:            target.ID,
			checkConnectionsRes: true,
			err:                 nil,
		},
		{
			desc:                "Delete client with failed to check connections",
			clientID:            target.ID,
			checkConnectionsErr: svcerr.ErrRemoveEntity,
			err:                 svcerr.ErrRemoveEntity,
		},
		{
			desc:                 "Delete client with failed to remove connections",
			clientID:             target.ID,
			checkConnectionsRes:  true,
			removeConnectionsErr: svcerr.ErrRemoveEntity,
			err:                  svcerr.ErrRemoveEntity,
		},
		{
			desc:      "Delete client with failed to remove from cache",
			clientID:  target.ID,
			removeErr: svcerr.ErrRemoveEntity,
			err:       svcerr.ErrRemoveEntity,
		},
		{
			desc:            "Delete client with failed to change status",
			clientID:        target.ID,
			changeStatusErr: svcerr.ErrNotFound,
			err:             svcerr.ErrRemoveEntity,
		},
		{
			desc:              "Delete client with failed to delete policies",
			clientID:          target.ID,
			deletePoliciesErr: svcerr.ErrNotFound,
			err:               svcerr.ErrDeletePolicies,
		},
		{
			desc:      "Delete client with failed to delete",
			clientID:  target.ID,
			deleteErr: svcerr.ErrNotFound,
			err:       svcerr.ErrRemoveEntity,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			repoCall := repo.On("DoesClientHaveConnections", context.Background(), mock.Anything).Return(tc.checkConnectionsRes, tc.checkConnectionsErr)
			channelsCall := chgRPCClient.On("RemoveClientConnections", context.Background(), &grpcChannelsV1.RemoveClientConnectionsReq{ClientId: tc.clientID}).Return(&grpcChannelsV1.RemoveClientConnectionsRes{}, tc.removeConnectionsErr)
			repoCall1 := cache.On("Remove", mock.Anything, tc.clientID).Return(tc.removeErr)
			repoCall2 := repo.On("ChangeStatus", context.Background(), clients.Client{ID: tc.clientID, Status: clients.DeletedStatus}).Return(target, tc.changeStatusErr)
			repoCall3 := repo.On("RetrieveEntitiesRolesActionsMembers", context.Background(), []string{tc.clientID}).Return([]roles.EntityActionRole{}, []roles.EntityMemberRole{}, nil)
			policyCall1 := pService.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr)
			policyCall2 := pService.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePoliciesErr)
			repoCall4 := repo.On("Delete", context.Background(), []string{tc.clientID}).Return(tc.deleteErr)

			err := svc.Delete(context.Background(), smqauthn.Session{}, tc.clientID)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))

			repoCall.Unset()
			repoCall1.Unset()
			policyCall1.Unset()
			repoCall2.Unset()
			channelsCall.Unset()
			repoCall3.Unset()
			repoCall4.Unset()
			policyCall2.Unset()
		})
	}
}
// TestSetParentGroup checks the error returned by Service.SetParentGroup.
// For every case the repository RetrieveByID and SetParentGroup calls, the
// groups gRPC RetrieveEntity call and the policy service AddPolicies /
// DeletePolicies calls are mocked; the policy mocks expect exactly the
// parent-group policy built from the case's session, group and client IDs.
func TestSetParentGroup(t *testing.T) {
	svc := newService()
	ctx := context.Background()

	// withFixedParent already points at validID; withRandomParent points at
	// a freshly generated group.
	withFixedParent := client
	withFixedParent.ParentGroup = validID
	withRandomParent := client
	withRandomParent.ParentGroup = testsutil.GenerateUUID(t)

	domainSession := smqauthn.Session{UserID: validID, DomainID: validID, DomainUserID: validID + "_" + validID}

	// groupIn builds the groups service reply for a group with a fresh ID in
	// the given domain and with the given status.
	groupIn := func(domainID string, status uint32) *grpcCommonV1.RetrieveEntityRes {
		return &grpcCommonV1.RetrieveEntityRes{
			Entity: &grpcCommonV1.EntityBasic{
				Id:       testsutil.GenerateUUID(t),
				DomainId: domainID,
				Status:   status,
			},
		}
	}
	enabledStatus := uint32(clients.EnabledStatus)
	disabledStatus := uint32(clients.DisabledStatus)

	type parentCase struct {
		desc          string
		clientID      string
		parentGroupID string
		session       smqauthn.Session
		retrieved     clients.Client // returned by the mocked RetrieveByID
		retrieveErr   error
		group         *grpcCommonV1.RetrieveEntityRes // returned by the mocked RetrieveEntity
		groupErr      error
		addErr        error // returned by the mocked AddPolicies
		deleteErr     error // returned by the mocked DeletePolicies
		setErr        error // returned by the mocked SetParentGroup
		wantErr       error
	}

	table := []parentCase{
		{
			desc:          "set parent group successfully",
			clientID:      client.ID,
			parentGroupID: testsutil.GenerateUUID(t),
			session:       domainSession,
			retrieved:     client,
			group:         groupIn(validID, enabledStatus),
		},
		{
			desc:          "set parent group with failed to retrieve client",
			clientID:      client.ID,
			parentGroupID: testsutil.GenerateUUID(t),
			session:       domainSession,
			retrieveErr:   svcerr.ErrNotFound,
			wantErr:       svcerr.ErrUpdateEntity,
		},
		{
			desc:          "set parent group with parent already set",
			clientID:      withFixedParent.ID,
			parentGroupID: validID,
			session:       domainSession,
			retrieved:     withFixedParent,
			wantErr:       svcerr.ErrConflict,
		},
		{
			desc:          "set parent group of client with existing parent group",
			clientID:      withRandomParent.ID,
			parentGroupID: testsutil.GenerateUUID(t),
			session:       domainSession,
			retrieved:     withRandomParent,
			wantErr:       svcerr.ErrConflict,
		},
		{
			desc:          "set parent group with failed to retrieve entity",
			clientID:      client.ID,
			parentGroupID: testsutil.GenerateUUID(t),
			session:       domainSession,
			retrieved:     client,
			groupErr:      svcerr.ErrAuthorization,
			wantErr:       svcerr.ErrUpdateEntity,
		},
		{
			desc:          "set parent group with parent group from different domain",
			clientID:      client.ID,
			parentGroupID: testsutil.GenerateUUID(t),
			session:       domainSession,
			retrieved:     client,
			group:         groupIn(testsutil.GenerateUUID(t), enabledStatus),
			wantErr:       svcerr.ErrUpdateEntity,
		},
		{
			desc:          "set parent group with disabled parent group",
			clientID:      client.ID,
			parentGroupID: testsutil.GenerateUUID(t),
			session:       domainSession,
			retrieved:     client,
			group:         groupIn(validID, disabledStatus),
			wantErr:       svcerr.ErrUpdateEntity,
		},
		{
			desc:          "set parent group with failed to add policies",
			clientID:      client.ID,
			parentGroupID: testsutil.GenerateUUID(t),
			session:       domainSession,
			retrieved:     client,
			group:         groupIn(validID, enabledStatus),
			addErr:        svcerr.ErrUpdateEntity,
			wantErr:       svcerr.ErrAddPolicies,
		},
		{
			desc:          "set parent group with failed to set parent group",
			clientID:      client.ID,
			parentGroupID: testsutil.GenerateUUID(t),
			session:       domainSession,
			retrieved:     client,
			group:         groupIn(validID, enabledStatus),
			setErr:        svcerr.ErrUpdateEntity,
			wantErr:       svcerr.ErrUpdateEntity,
		},
		{
			desc:          "set parent group with failed to set parent group and failed rollback",
			clientID:      client.ID,
			parentGroupID: testsutil.GenerateUUID(t),
			session:       domainSession,
			retrieved:     client,
			group:         groupIn(validID, enabledStatus),
			setErr:        svcerr.ErrUpdateEntity,
			deleteErr:     svcerr.ErrAuthorization,
			wantErr:       apiutil.ErrRollbackTx,
		},
	}

	for _, tt := range table {
		t.Run(tt.desc, func(t *testing.T) {
			// The single policy linking the parent group to the client.
			parentPolicy := policysvc.Policy{
				Domain:      tt.session.DomainID,
				SubjectType: policysvc.GroupType,
				Subject:     tt.parentGroupID,
				Relation:    policysvc.ParentGroupRelation,
				ObjectType:  policysvc.ClientType,
				Object:      tt.clientID,
			}
			policies := []policysvc.Policy{parentPolicy}
			groupReq := &grpcCommonV1.RetrieveEntityReq{Id: tt.parentGroupID}

			registered := []*mock.Call{
				repo.On("RetrieveByID", ctx, tt.clientID).Return(tt.retrieved, tt.retrieveErr),
				gpgRPCClient.On("RetrieveEntity", ctx, groupReq).Return(tt.group, tt.groupErr),
				pService.On("AddPolicies", ctx, policies).Return(tt.addErr),
				pService.On("DeletePolicies", ctx, policies).Return(tt.deleteErr),
				repo.On("SetParentGroup", ctx, mock.Anything).Return(tt.setErr),
			}

			err := svc.SetParentGroup(ctx, tt.session, tt.parentGroupID, tt.clientID)
			assert.True(t, errors.Contains(err, tt.wantErr), fmt.Sprintf("%s: expected %s got %s\n", tt.desc, tt.wantErr, err))

			for _, call := range registered {
				call.Unset()
			}
		})
	}
}
// TestRemoveParentGroup runs svc.RemoveParentGroup against mocked repository
// and policy-service calls and checks, for every table entry, that the
// returned error contains the expected one.
func TestRemoveParentGroup(t *testing.T) {
	svc := newService()

	// Client fixture that already has a parent group assigned.
	parented := client
	parented.ParentGroup = validID

	// Every case is executed with the same session.
	sess := smqauthn.Session{UserID: validID, DomainID: validID, DomainUserID: validID + "_" + validID}

	// testCase bundles the mock return values and the expected outcome of a
	// single RemoveParentGroup call.
	type testCase struct {
		desc                 string
		clientID             string
		session              smqauthn.Session
		retrieveByIDResp     clients.Client
		retrieveByIDErr      error
		deletePoliciesErr    error
		addPoliciesErr       error
		removeParentGroupErr error
		err                  error
	}

	tests := []testCase{
		{
			desc:             "remove parent group successfully",
			clientID:         parented.ID,
			session:          sess,
			retrieveByIDResp: parented,
			err:              nil,
		},
		{
			desc:             "remove parent group with failed to retrieve client",
			clientID:         parented.ID,
			session:          sess,
			retrieveByIDResp: clients.Client{},
			retrieveByIDErr:  svcerr.ErrNotFound,
			err:              svcerr.ErrViewEntity,
		},
		{
			desc:              "remove parent group with failed to delete policies",
			clientID:          parented.ID,
			session:           sess,
			retrieveByIDResp:  parented,
			deletePoliciesErr: svcerr.ErrAuthorization,
			err:               svcerr.ErrDeletePolicies,
		},
		{
			desc:                 "remove parent group with failed to remove parent group",
			clientID:             parented.ID,
			session:              sess,
			retrieveByIDResp:     parented,
			removeParentGroupErr: svcerr.ErrUpdateEntity,
			err:                  svcerr.ErrUpdateEntity,
		},
		{
			// Both the repository update and the AddPolicies mock fail; the
			// expected error is the rollback one.
			desc:                 "remove parent group with failed to remove parent group and failed to add policies",
			clientID:             parented.ID,
			session:              sess,
			retrieveByIDResp:     parented,
			removeParentGroupErr: svcerr.ErrUpdateEntity,
			addPoliciesErr:       svcerr.ErrUpdateEntity,
			err:                  apiutil.ErrRollbackTx,
		},
	}

	for _, tt := range tests {
		t.Run(tt.desc, func(t *testing.T) {
			ctx := context.Background()

			// Parent-group policy the policy-service mocks are registered
			// with; its subject is the parent group of the retrieved client.
			parentPolicy := policysvc.Policy{
				Domain:      tt.session.DomainID,
				SubjectType: policysvc.GroupType,
				Subject:     tt.retrieveByIDResp.ParentGroup,
				Relation:    policysvc.ParentGroupRelation,
				ObjectType:  policysvc.ClientType,
				Object:      tt.clientID,
			}
			policies := []policysvc.Policy{parentPolicy}

			retrieveCall := repo.On("RetrieveByID", ctx, tt.clientID).Return(tt.retrieveByIDResp, tt.retrieveByIDErr)
			deleteCall := pService.On("DeletePolicies", ctx, policies).Return(tt.deletePoliciesErr)
			addCall := pService.On("AddPolicies", ctx, policies).Return(tt.addPoliciesErr)
			removeCall := repo.On("RemoveParentGroup", ctx, mock.Anything).Return(tt.removeParentGroupErr)

			err := svc.RemoveParentGroup(ctx, tt.session, tt.clientID)

			matched := errors.Contains(err, tt.err)
			failMsg := fmt.Sprintf("%s: expected %s got %s\n", tt.desc, tt.err, err)
			assert.True(t, matched, failMsg)

			// Remove the expectations registered above before the next case
			// registers its own.
			for _, call := range []*mock.Call{retrieveCall, deleteCall, removeCall, addCall} {
				call.Unset()
			}
		})
	}
}