// Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 package users_test import ( "context" "fmt" "strings" "testing" "time" grpcTokenV1 "github.com/absmach/magistrala/api/grpc/token/v1" smqauth "github.com/absmach/magistrala/auth" authmocks "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/internal/testsutil" "github.com/absmach/magistrala/pkg/authn" "github.com/absmach/magistrala/pkg/errors" repoerr "github.com/absmach/magistrala/pkg/errors/repository" svcerr "github.com/absmach/magistrala/pkg/errors/service" policymocks "github.com/absmach/magistrala/pkg/policies/mocks" "github.com/absmach/magistrala/pkg/uuid" "github.com/absmach/magistrala/users" "github.com/absmach/magistrala/users/hasher" "github.com/absmach/magistrala/users/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) var ( idProvider = uuid.New() phasher = hasher.New() secret = "strongsecret" validCMetadata = users.Metadata{"role": "user"} userID = "d8dd12ef-aa2a-43fe-8ef2-2e4fe514360f" user = users.User{ ID: userID, FirstName: "firstname", LastName: "lastname", Tags: []string{"tag1", "tag2"}, Credentials: users.Credentials{Username: "username", Secret: secret}, Email: "useremail@email.com", Metadata: validCMetadata, PrivateMetadata: validCMetadata, Status: users.EnabledStatus, } basicUser = users.User{ Credentials: users.Credentials{ Username: "username", }, ID: userID, FirstName: "firstname", LastName: "lastname", } validToken = "token" validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22" wrongID = testsutil.GenerateUUID(&testing.T{}) errHashPassword = errors.New("generate hash from password failed") ) func newService() (users.Service, *authmocks.TokenServiceClient, *mocks.Repository, *policymocks.Service, *mocks.Emailer) { cRepo := new(mocks.Repository) policies := new(policymocks.Service) e := new(mocks.Emailer) tokenClient := new(authmocks.TokenServiceClient) return users.NewService(tokenClient, cRepo, policies, e, phasher, 
idProvider), tokenClient, cRepo, policies, e } func newServiceMinimal() (users.Service, *mocks.Repository) { cRepo := new(mocks.Repository) policies := new(policymocks.Service) e := new(mocks.Emailer) tokenUser := new(authmocks.TokenServiceClient) return users.NewService(tokenUser, cRepo, policies, e, phasher, idProvider), cRepo } func TestRegister(t *testing.T) { svc, _, cRepo, policies, _ := newService() cases := []struct { desc string user users.User addPoliciesResponseErr error deletePoliciesResponseErr error saveErr error err error }{ { desc: "register new user successfully", user: user, err: nil, }, { desc: "register existing user", user: user, saveErr: repoerr.ErrConflict, err: repoerr.ErrConflict, }, { desc: "register a new enabled user with name", user: users.User{ FirstName: "userWithName", Email: "newuserwithname@example.com", Credentials: users.Credentials{ Secret: secret, }, Status: users.EnabledStatus, }, err: nil, }, { desc: "register a new disabled user with name", user: users.User{ FirstName: "userWithName", Email: "newuserwithname@example.com", Credentials: users.Credentials{ Secret: secret, }, }, err: nil, }, { desc: "register a new user with all fields", user: users.User{ FirstName: "newuserwithallfields", Tags: []string{"tag1", "tag2"}, Email: "newuserwithallfields@example.com", Credentials: users.Credentials{ Secret: secret, }, PrivateMetadata: users.Metadata{ "name": "newuserwithallfields", }, Metadata: users.Metadata{ "name": "newuserwithallfields", }, Status: users.EnabledStatus, }, err: nil, }, { desc: "register a new user with missing email", user: users.User{ FirstName: "userWithMissingEmail", Credentials: users.Credentials{ Secret: secret, }, }, saveErr: errors.ErrMalformedEntity, err: errors.ErrMalformedEntity, }, { desc: "register a new user with missing secret", user: users.User{ FirstName: "userWithMissingSecret", Email: "userwithmissingsecret@example.com", Credentials: users.Credentials{ Secret: "", }, }, err: nil, }, { desc: " 
register a user with a secret that is too long", user: users.User{ FirstName: "userWithLongSecret", Email: "userwithlongsecret@example.com", Credentials: users.Credentials{ Secret: strings.Repeat("a", 73), }, }, err: errHashPassword, }, { desc: "register a new user with invalid status", user: users.User{ FirstName: "userWithInvalidStatus", Email: "user with invalid status", Credentials: users.Credentials{ Secret: secret, }, Status: users.AllStatus, }, err: svcerr.ErrInvalidStatus, }, { desc: "register a new user with invalid role", user: users.User{ FirstName: "userWithInvalidRole", Email: "userwithinvalidrole@example.com", Credentials: users.Credentials{ Secret: secret, }, Role: 2, }, err: svcerr.ErrInvalidRole, }, { desc: "register a new user with failed to add policies with err", user: users.User{ FirstName: "userWithFailedToAddPolicies", Email: "userwithfailedpolicies@example.com", Credentials: users.Credentials{ Secret: secret, }, Role: users.AdminRole, }, addPoliciesResponseErr: svcerr.ErrAddPolicies, err: svcerr.ErrAddPolicies, }, { desc: "register a new user with failed to delete policies with err", user: users.User{ FirstName: "userWithFailedToDeletePolicies", Email: "userwithfailedtodelete@example.com", Credentials: users.Credentials{ Secret: secret, }, Role: users.AdminRole, }, deletePoliciesResponseErr: svcerr.ErrConflict, saveErr: repoerr.ErrConflict, err: svcerr.ErrConflict, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr) policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr) repoCall := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr) expected, err := svc.Register(context.Background(), authn.Session{}, tc.user, true) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if 
err == nil { tc.user.ID = expected.ID tc.user.CreatedAt = expected.CreatedAt tc.user.UpdatedAt = expected.UpdatedAt tc.user.Credentials.Secret = expected.Credentials.Secret tc.user.UpdatedBy = expected.UpdatedBy assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected)) ok := repoCall.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) } repoCall.Unset() policyCall.Unset() policyCall1.Unset() }) } svc, _, cRepo, policies, _ = newService() cases2 := []struct { desc string user users.User session authn.Session addPoliciesResponseErr error deletePoliciesResponseErr error saveErr error checkSuperAdminErr error err error }{ { desc: "register new user successfully as admin", user: user, session: authn.Session{UserID: validID, SuperAdmin: true}, err: nil, }, { desc: "register a new user as admin with failed check on super admin", user: user, session: authn.Session{UserID: validID, SuperAdmin: false}, checkSuperAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, } for _, tc := range cases2 { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr) policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr) repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr) expected, err := svc.Register(context.Background(), authn.Session{UserID: validID}, tc.user, false) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { tc.user.ID = expected.ID tc.user.CreatedAt = expected.CreatedAt tc.user.UpdatedAt = expected.UpdatedAt tc.user.Credentials.Secret = expected.Credentials.Secret tc.user.UpdatedBy = 
expected.UpdatedBy assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected)) ok := repoCall1.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) } repoCall1.Unset() policyCall.Unset() policyCall1.Unset() repoCall.Unset() } } func TestViewUser(t *testing.T) { svc, cRepo := newServiceMinimal() cases := []struct { desc string token string reqUserID string userID string retrieveByIDResponse users.User response users.User identifyErr error authorizeErr error retrieveByIDErr error checkSuperAdminErr error err error }{ { desc: "view user as normal user successfully", retrieveByIDResponse: user, response: user, token: validToken, reqUserID: user.ID, userID: user.ID, err: nil, checkSuperAdminErr: svcerr.ErrAuthorization, }, { desc: "view user as normal user with failed to retrieve user", retrieveByIDResponse: users.User{}, token: validToken, reqUserID: user.ID, userID: user.ID, retrieveByIDErr: repoerr.ErrNotFound, err: svcerr.ErrNotFound, checkSuperAdminErr: svcerr.ErrAuthorization, }, { desc: "view user as admin user successfully", retrieveByIDResponse: user, response: user, token: validToken, reqUserID: user.ID, userID: user.ID, err: nil, }, { desc: "view user as admin user with failed check on super admin", token: validToken, retrieveByIDResponse: basicUser, response: basicUser, reqUserID: user.ID, userID: "", checkSuperAdminErr: svcerr.ErrAuthorization, err: nil, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr) rUser, err := svc.View(context.Background(), authn.Session{UserID: tc.reqUserID}, tc.userID) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, 
tc.err, err)) tc.response.Credentials.Secret = "" assert.Equal(t, tc.response, rUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rUser)) if tc.err == nil { ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.userID) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) } repoCall1.Unset() repoCall.Unset() }) } } func TestListUsers(t *testing.T) { svc, cRepo := newServiceMinimal() cases := []struct { desc string token string page users.Page retrieveAllResponse users.UsersPage response users.UsersPage size uint64 retrieveAllErr error superAdminErr error err error }{ { desc: "list clients as admin successfully", page: users.Page{ Total: 1, }, retrieveAllResponse: users.UsersPage{ Page: users.Page{ Total: 1, }, Users: []users.User{user}, }, response: users.UsersPage{ Page: users.Page{ Total: 1, }, Users: []users.User{user}, }, token: validToken, err: nil, }, { desc: "list clients as admin with failed to retrieve clients", page: users.Page{ Total: 1, }, retrieveAllResponse: users.UsersPage{}, token: validToken, retrieveAllErr: repoerr.ErrNotFound, err: svcerr.ErrViewEntity, }, { desc: "list clients as admin with failed check on super admin", page: users.Page{ Total: 1, }, token: validToken, superAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, { desc: "list clients as normal user with failed to retrieve clients", page: users.Page{ Total: 1, }, retrieveAllResponse: users.UsersPage{}, token: validToken, retrieveAllErr: repoerr.ErrNotFound, err: svcerr.ErrViewEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.superAdminErr) repoCall1 := cRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr) page, err := svc.ListUsers(context.Background(), authn.Session{UserID: user.ID}, tc.page) assert.True(t, errors.Contains(err, 
tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) if tc.err == nil { ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() }) } } func TestSearchUsers(t *testing.T) { svc, cRepo := newServiceMinimal() cases := []struct { desc string token string page users.Page response users.UsersPage responseErr error err error }{ { desc: "search clients with valid token", token: validToken, page: users.Page{Offset: 0, FirstName: "username", Limit: 100}, response: users.UsersPage{ Page: users.Page{Total: 1, Offset: 0, Limit: 100}, Users: []users.User{user}, }, }, { desc: "search clients with id", token: validToken, page: users.Page{Offset: 0, Id: "d8dd12ef-aa2a-43fe-8ef2-2e4fe514360f", Limit: 100}, response: users.UsersPage{ Page: users.Page{Total: 1, Offset: 0, Limit: 100}, Users: []users.User{user}, }, }, { desc: "search clients with random name", token: validToken, page: users.Page{Offset: 0, FirstName: "randomname", Limit: 100}, response: users.UsersPage{ Page: users.Page{Total: 0, Offset: 0, Limit: 100}, Users: []users.User{}, }, }, { desc: "search clients with repo failed", token: validToken, page: users.Page{Offset: 0, FirstName: "randomname", Limit: 100}, response: users.UsersPage{ Page: users.Page{Total: 0, Offset: 0, Limit: 0}, }, responseErr: repoerr.ErrViewEntity, err: svcerr.ErrViewEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("SearchUsers", context.Background(), mock.Anything).Return(tc.response, tc.responseErr) page, err := svc.SearchUsers(context.Background(), tc.page) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got 
%v\n", tc.desc, tc.response, page)) repoCall.Unset() }) } } func TestUpdateUser(t *testing.T) { svc, cRepo := newServiceMinimal() user1 := user user2 := user updateFirstName := "Updated user" user1.FirstName = updateFirstName updatedMetadata := users.Metadata{"role": "test"} invalidMetadata := users.Metadata{"role": make(chan int)} user2.PrivateMetadata = updatedMetadata user2.Metadata = updatedMetadata adminID := testsutil.GenerateUUID(t) cases := []struct { desc string userID string userReq users.UserReq session authn.Session updateResponse users.User retrieveByIDResp users.User retrieveByIDErr error token string updateErr error checkSuperAdminErr error err error }{ { desc: "update user name successfully as normal user", userID: user1.ID, userReq: users.UserReq{ FirstName: &updateFirstName, }, session: authn.Session{UserID: user1.ID}, updateResponse: user1, retrieveByIDResp: user1, token: validToken, err: nil, }, { desc: "update private metadata successfully as normal user", userID: user2.ID, userReq: users.UserReq{ PrivateMetadata: &updatedMetadata, }, session: authn.Session{UserID: user2.ID}, updateResponse: user2, token: validToken, err: nil, }, { desc: "update private metadata with repo error", userID: user2.ID, userReq: users.UserReq{ PrivateMetadata: &invalidMetadata, }, session: authn.Session{UserID: user2.ID}, updateResponse: users.User{}, token: validToken, updateErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, { desc: "update metadata successfully as normal user", userID: user2.ID, userReq: users.UserReq{ Metadata: &updatedMetadata, }, session: authn.Session{UserID: user2.ID}, updateResponse: user2, retrieveByIDResp: user2, token: validToken, err: nil, }, { desc: "update metadata with repo error", userID: user2.ID, userReq: users.UserReq{ Metadata: &invalidMetadata, }, session: authn.Session{UserID: user2.ID}, updateResponse: users.User{}, retrieveByIDResp: user2, token: validToken, updateErr: errors.ErrMalformedEntity, err: 
svcerr.ErrUpdateEntity, }, { desc: "update user name as normal user with repo error on update", userID: user1.ID, userReq: users.UserReq{ FirstName: &updateFirstName, }, session: authn.Session{UserID: user1.ID}, updateResponse: users.User{}, retrieveByIDResp: user1, token: validToken, updateErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, { desc: "update user name as admin successfully", userID: user1.ID, userReq: users.UserReq{ FirstName: &updateFirstName, }, session: authn.Session{UserID: adminID, SuperAdmin: true}, updateResponse: user1, retrieveByIDResp: user1, token: validToken, err: nil, }, { desc: "update user private metadata as admin successfully", userID: user2.ID, userReq: users.UserReq{ PrivateMetadata: &updatedMetadata, }, session: authn.Session{UserID: adminID, SuperAdmin: true}, updateResponse: user2, retrieveByIDResp: user2, token: validToken, err: nil, }, { desc: "update user with failed check on super admin", userID: user1.ID, userReq: users.UserReq{ FirstName: &updateFirstName, }, session: authn.Session{UserID: adminID}, token: validToken, checkSuperAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, { desc: "update user name as admin with repo error on update", userID: user1.ID, userReq: users.UserReq{ FirstName: &updateFirstName, }, session: authn.Session{UserID: adminID, SuperAdmin: true}, updateResponse: users.User{}, retrieveByIDResp: user1, token: validToken, updateErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, { desc: "update user first name with external auth provider should fail", userID: user1.ID, userReq: users.UserReq{ FirstName: &updateFirstName, }, session: authn.Session{UserID: user1.ID}, retrieveByIDResp: users.User{ ID: user1.ID, AuthProvider: "google", }, token: validToken, err: svcerr.ErrExternalAuthProviderCouldNotUpdate, }, { desc: "update user last name with external auth provider should fail", userID: user1.ID, userReq: users.UserReq{ LastName: &updateFirstName, }, session: 
authn.Session{UserID: user1.ID}, retrieveByIDResp: users.User{ ID: user1.ID, AuthProvider: "google", }, token: validToken, err: svcerr.ErrExternalAuthProviderCouldNotUpdate, }, { desc: "update user privatemetadata with external auth provider should succeed", userID: user2.ID, userReq: users.UserReq{ PrivateMetadata: &updatedMetadata, }, session: authn.Session{UserID: user2.ID}, retrieveByIDResp: users.User{ ID: user2.ID, AuthProvider: "google", PrivateMetadata: updatedMetadata, }, updateResponse: users.User{ ID: user2.ID, AuthProvider: "google", PrivateMetadata: updatedMetadata, }, token: validToken, err: nil, }, { desc: "update user with retrieve by id error", userID: user1.ID, userReq: users.UserReq{ FirstName: &updateFirstName, }, session: authn.Session{UserID: user1.ID}, retrieveByIDErr: repoerr.ErrNotFound, token: validToken, err: svcerr.ErrUpdateEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr) repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateResponse, tc.updateErr) updatedUser, err := svc.Update(context.Background(), tc.session, tc.userID, tc.userReq) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.updateResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateResponse, updatedUser)) if tc.err == nil { ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() repoCall2.Unset() }) } } func TestUpdateTags(t *testing.T) { svc, cRepo := newServiceMinimal() updateTags := []string{"tag1", "tag2"} user.Tags = updateTags 
adminID := testsutil.GenerateUUID(t) cases := []struct { desc string userID string userReq users.UserReq session authn.Session updateUserTagsResponse users.User updateUserTagsErr error checkSuperAdminErr error err error }{ { desc: "update user tags as normal user successfully", userID: user.ID, userReq: users.UserReq{Tags: &updateTags}, session: authn.Session{UserID: user.ID}, updateUserTagsResponse: user, err: nil, }, { desc: "update user tags as normal user with repo error on update", userID: user.ID, userReq: users.UserReq{Tags: &updateTags}, session: authn.Session{UserID: user.ID}, updateUserTagsResponse: users.User{}, updateUserTagsErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, { desc: "update user tags as admin successfully", userID: user.ID, userReq: users.UserReq{Tags: &updateTags}, session: authn.Session{UserID: adminID, SuperAdmin: true}, err: nil, }, { desc: "update user tags as admin with failed check on super admin", userID: user.ID, userReq: users.UserReq{Tags: &updateTags}, session: authn.Session{UserID: adminID}, checkSuperAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, { desc: "update user tags as admin with repo error on update", userID: user.ID, userReq: users.UserReq{Tags: &updateTags}, session: authn.Session{UserID: adminID, SuperAdmin: true}, updateUserTagsResponse: users.User{}, updateUserTagsErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) repoCall1 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateUserTagsResponse, tc.updateUserTagsErr) updatedUser, err := svc.UpdateTags(context.Background(), tc.session, tc.userID, tc.userReq) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.updateUserTagsResponse, 
updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUserTagsResponse, updatedUser)) if tc.err == nil { ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() }) } } func TestUpdateRole(t *testing.T) { svc, _, cRepo, policies, _ := newService() user2 := user user.Role = users.AdminRole user2.Role = users.UserRole cases := []struct { desc string user users.User session authn.Session updateRoleResponse users.User deletePolicyErr error addPolicyErr error updateRoleErr error checkSuperAdminErr error err error }{ { desc: "update user role successfully", user: user, session: authn.Session{UserID: validID, SuperAdmin: true}, updateRoleResponse: user, err: nil, }, { desc: "update user role with failed check on super admin", user: user, session: authn.Session{UserID: validID, SuperAdmin: false}, checkSuperAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, { desc: "update user role with failed to add policies", user: user, session: authn.Session{UserID: validID, SuperAdmin: true}, addPolicyErr: errors.ErrMalformedEntity, err: svcerr.ErrAddPolicies, }, { desc: "update user role to user role successfully ", user: user2, session: authn.Session{UserID: validID, SuperAdmin: true}, updateRoleResponse: user2, err: nil, }, { desc: "update user role to user role with failed to delete policies", user: user2, session: authn.Session{UserID: validID, SuperAdmin: true}, deletePolicyErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, { desc: "update user role to user role with failed to delete policies with error", user: user2, session: authn.Session{UserID: validID, SuperAdmin: true}, deletePolicyErr: svcerr.ErrMalformedEntity, err: svcerr.ErrDeletePolicies, }, { desc: "Update user with failed repo update and roll back", user: user, session: authn.Session{UserID: validID, SuperAdmin: true}, 
updateRoleErr: svcerr.ErrAuthentication, err: svcerr.ErrAuthentication, }, { desc: "Update user with failed repo update and failedroll back", user: user, session: authn.Session{UserID: validID, SuperAdmin: true}, deletePolicyErr: svcerr.ErrAuthorization, updateRoleErr: svcerr.ErrAuthentication, err: svcerr.ErrAuthentication, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) policyCall := policies.On("AddPolicy", context.Background(), mock.Anything).Return(tc.addPolicyErr) policyCall1 := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePolicyErr) repoCall1 := cRepo.On("UpdateRole", context.Background(), mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr) updatedUser, err := svc.UpdateRole(context.Background(), tc.session, tc.user) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.updateRoleResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateRoleResponse, updatedUser)) if tc.err == nil { ok := repoCall1.Parent.AssertCalled(t, "UpdateRole", context.Background(), mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } repoCall.Unset() policyCall.Unset() policyCall1.Unset() repoCall1.Unset() }) } } func TestUpdateSecret(t *testing.T) { svc, authUser, cRepo, _, _ := newService() newSecret := "newstrongSecret" rUser := user rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret) responseUser := user responseUser.Credentials.Secret = newSecret cases := []struct { desc string oldSecret string newSecret string session authn.Session retrieveByIDResponse users.User retrieveByEmailResponse users.User updateSecretResponse users.User issueResponse *grpcTokenV1.Token response users.User retrieveByIDErr error retrieveByEmailErr error 
updateSecretErr error issueErr error err error }{ { desc: "update user secret with valid token", oldSecret: user.Credentials.Secret, newSecret: newSecret, session: authn.Session{UserID: user.ID}, retrieveByEmailResponse: rUser, retrieveByIDResponse: user, updateSecretResponse: responseUser, issueResponse: &grpcTokenV1.Token{AccessToken: validToken}, response: responseUser, err: nil, }, { desc: "update user secret with failed to retrieve user by ID", oldSecret: user.Credentials.Secret, newSecret: newSecret, session: authn.Session{UserID: user.ID}, retrieveByIDResponse: users.User{}, retrieveByIDErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "update user secret with failed to retrieve user by email", oldSecret: user.Credentials.Secret, newSecret: newSecret, session: authn.Session{UserID: user.ID}, retrieveByIDResponse: user, retrieveByEmailResponse: users.User{}, retrieveByEmailErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "update user secret with invalid old secret", oldSecret: "invalid", newSecret: newSecret, session: authn.Session{UserID: user.ID}, retrieveByIDResponse: user, retrieveByEmailResponse: rUser, err: svcerr.ErrLogin, }, { desc: "update user secret with too long new secret", oldSecret: user.Credentials.Secret, newSecret: strings.Repeat("a", 73), session: authn.Session{UserID: user.ID}, retrieveByIDResponse: user, retrieveByEmailResponse: rUser, err: errHashPassword, }, { desc: "update user secret with failed to update secret", oldSecret: user.Credentials.Secret, newSecret: newSecret, session: authn.Session{UserID: user.ID}, retrieveByIDResponse: user, retrieveByEmailResponse: rUser, updateSecretResponse: users.User{}, updateSecretErr: repoerr.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("RetrieveByID", context.Background(), user.ID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr) repoCall1 := cRepo.On("RetrieveByUsername", 
context.Background(), user.Credentials.Username).Return(tc.retrieveByEmailResponse, tc.retrieveByEmailErr) repoCall2 := cRepo.On("UpdateSecret", context.Background(), mock.Anything).Return(tc.updateSecretResponse, tc.updateSecretErr) authCall := authUser.On("Issue", context.Background(), mock.Anything).Return(tc.issueResponse, tc.issueErr) updatedUser, err := svc.UpdateSecret(context.Background(), tc.session, tc.oldSecret, tc.newSecret) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.response, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, updatedUser)) if tc.err == nil { ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.response.ID) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) ok = repoCall1.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.response.Credentials.Username) assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc)) ok = repoCall2.Parent.AssertCalled(t, "UpdateSecret", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("UpdateSecret was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() repoCall2.Unset() authCall.Unset() }) } } func TestUpdateEmail(t *testing.T) { svc, cRepo := newServiceMinimal() user2 := user user2.Email = "user2@example.com" cases := []struct { desc string email string token string reqUserID string id string updateEmailResponse users.User updateEmailErr error checkSuperAdminErr error err error }{ { desc: "update user as normal user successfully", email: "user2-update-1@example.com", token: validToken, reqUserID: user.ID, id: user.ID, updateEmailResponse: user2, err: nil, }, { desc: "update to same email as normal user successfully", email: "user2-update-1@example.com", token: validToken, reqUserID: user.ID, id: user.ID, updateEmailResponse: user2, err: nil, }, { desc: "update user email as 
normal user with repo error on update", email: "user2-update-2@example.com", token: validToken, reqUserID: user.ID, id: user.ID, updateEmailResponse: users.User{}, updateEmailErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, { desc: "update user email as admin successfully", email: "user2-update-3@example.com", token: validToken, id: user.ID, err: nil, }, { desc: "update user email as admin with repo error on update", email: "user2-update-4@exmaple.com", token: validToken, reqUserID: user.ID, id: user.ID, updateEmailResponse: users.User{}, updateEmailErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, { desc: "update user as admin user with failed check on super admin", email: "user2-update-5@exmaple.com", token: validToken, reqUserID: user.ID, id: "", updateEmailResponse: users.User{}, updateEmailErr: errors.ErrMalformedEntity, checkSuperAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) repocall2 := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr) repoCall1 := cRepo.On("UpdateEmail", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr) updatedUser, err := svc.UpdateEmail(context.Background(), authn.Session{DomainUserID: tc.reqUserID, UserID: validID, DomainID: validID}, tc.id, tc.email) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.updateEmailResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateEmailResponse, updatedUser)) if tc.err == nil && user2.Email != tc.email { ok := repoCall1.Parent.AssertCalled(t, "UpdateEmail", context.Background(), mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) 
user2.Email = tc.email } repoCall.Unset() repocall2.Unset() repoCall1.Unset() }) } } func TestUpdateProfilePicture(t *testing.T) { svc, cRepo := newServiceMinimal() updatedPicture := "https://example.com/profile.jpg" user.ProfilePicture = updatedPicture adminID := testsutil.GenerateUUID(t) cases := []struct { desc string userID string userReq users.UserReq session authn.Session updateProfilePicResponse users.User retrieveByIDResp users.User retrieveByIDErr error updateProfilePicErr error checkSuperAdminErr error err error }{ { desc: "update profile picture as normal user successfully", userID: user.ID, userReq: users.UserReq{ProfilePicture: &updatedPicture}, session: authn.Session{UserID: user.ID}, updateProfilePicResponse: user, retrieveByIDResp: user, err: nil, }, { desc: "update profile picture as normal user with repo error on update", userID: user.ID, userReq: users.UserReq{ProfilePicture: &updatedPicture}, session: authn.Session{UserID: user.ID}, updateProfilePicResponse: users.User{}, retrieveByIDResp: user, updateProfilePicErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, { desc: "update profile picture as admin successfully", userID: user.ID, userReq: users.UserReq{ProfilePicture: &updatedPicture}, session: authn.Session{UserID: adminID, SuperAdmin: true}, retrieveByIDResp: user, err: nil, }, { desc: "update profile picture as admin with failed check on super admin", userID: user.ID, userReq: users.UserReq{ProfilePicture: &updatedPicture}, session: authn.Session{UserID: adminID}, checkSuperAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, { desc: "update profile picture as admin with repo error on update", userID: user.ID, userReq: users.UserReq{ProfilePicture: &updatedPicture}, session: authn.Session{UserID: adminID, SuperAdmin: true}, updateProfilePicResponse: users.User{}, retrieveByIDResp: user, updateProfilePicErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, { desc: "update profile picture with external 
auth provider", userID: user.ID, userReq: users.UserReq{ProfilePicture: &updatedPicture}, session: authn.Session{UserID: user.ID}, retrieveByIDResp: users.User{ ID: user.ID, AuthProvider: "google", }, err: svcerr.ErrExternalAuthProviderCouldNotUpdate, }, { desc: "update profile picture with retrieve by id error", userID: user.ID, userReq: users.UserReq{ProfilePicture: &updatedPicture}, session: authn.Session{UserID: user.ID}, retrieveByIDErr: repoerr.ErrNotFound, err: svcerr.ErrUpdateEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr) repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateProfilePicResponse, tc.updateProfilePicErr) updatedUser, err := svc.UpdateProfilePicture(context.Background(), tc.session, tc.userID, tc.userReq) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.updateProfilePicResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateProfilePicResponse, updatedUser)) if tc.err == nil { ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() repoCall2.Unset() }) } } func TestUpdateUsername(t *testing.T) { svc, cRepo := newServiceMinimal() nuser := user nuser.Credentials.Username = "newusername" adminID := testsutil.GenerateUUID(t) cases := []struct { desc string user users.User session authn.Session updateUsernameResponse users.User updateUsernameErr error checkSuperAdminErr error err error }{ { desc: "update username as normal user successfully", user: user, session: authn.Session{UserID: user.ID}, 
updateUsernameResponse: nuser, err: nil, }, { desc: "update username as normal user with repo error on update", user: user, session: authn.Session{UserID: user.ID}, updateUsernameResponse: users.User{}, updateUsernameErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, { desc: "update username as admin successfully", user: user, session: authn.Session{UserID: adminID, SuperAdmin: true}, updateUsernameResponse: nuser, err: nil, }, { desc: "update username as admin with failed check on super admin", user: user, session: authn.Session{UserID: adminID}, checkSuperAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, { desc: "update username as admin with repo error on update", user: user, session: authn.Session{UserID: adminID, SuperAdmin: true}, updateUsernameResponse: users.User{}, updateUsernameErr: errors.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) repoCall1 := cRepo.On("UpdateUsername", context.Background(), mock.Anything).Return(tc.updateUsernameResponse, tc.updateUsernameErr) updatedUser, err := svc.UpdateUsername(context.Background(), tc.session, tc.user.ID, tc.user.Credentials.Username) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.updateUsernameResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUsernameResponse, updatedUser)) if tc.err == nil { ok := repoCall1.Parent.AssertCalled(t, "UpdateUsername", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("UpdateUsername was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() }) } } func TestEnableUser(t *testing.T) { svc, cRepo := newServiceMinimal() enabledUser1 := users.User{ID: testsutil.GenerateUUID(t), Credentials: users.Credentials{Username: "user1@example.com", 
Secret: "password"}, Status: users.EnabledStatus} disabledUser1 := users.User{ID: testsutil.GenerateUUID(t), Credentials: users.Credentials{Username: "user3@example.com", Secret: "password"}, Status: users.DisabledStatus} endisabledUser1 := disabledUser1 endisabledUser1.Status = users.EnabledStatus cases := []struct { desc string id string user users.User retrieveByIDResponse users.User changeStatusResponse users.User response users.User retrieveByIDErr error changeStatusErr error checkSuperAdminErr error err error }{ { desc: "enable disabled user", id: disabledUser1.ID, user: disabledUser1, retrieveByIDResponse: disabledUser1, changeStatusResponse: endisabledUser1, response: endisabledUser1, err: nil, }, { desc: "enable disabled user with normal user token", id: disabledUser1.ID, user: disabledUser1, checkSuperAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, { desc: "enable disabled user with failed to retrieve user by ID", id: disabledUser1.ID, user: disabledUser1, retrieveByIDResponse: users.User{}, retrieveByIDErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "enable already enabled user", id: enabledUser1.ID, user: enabledUser1, retrieveByIDResponse: enabledUser1, err: svcerr.ErrStatusAlreadyAssigned, }, { desc: "enable disabled user with failed to change status", id: disabledUser1.ID, user: disabledUser1, retrieveByIDResponse: disabledUser1, changeStatusResponse: users.User{}, changeStatusErr: repoerr.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr) repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr) _, err := svc.Enable(context.Background(), authn.Session{}, tc.id) 
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if tc.err == nil { ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() repoCall2.Unset() }) } } func TestDisableUser(t *testing.T) { svc, cRepo := newServiceMinimal() enabledUser1 := users.User{ID: testsutil.GenerateUUID(t), Credentials: users.Credentials{Username: "user1@example.com", Secret: "password"}, Status: users.EnabledStatus} disabledUser1 := users.User{ID: testsutil.GenerateUUID(t), Credentials: users.Credentials{Username: "user3@example.com", Secret: "password"}, Status: users.DisabledStatus} disenabledUser1 := enabledUser1 disenabledUser1.Status = users.DisabledStatus cases := []struct { desc string id string user users.User retrieveByIDResponse users.User changeStatusResponse users.User response users.User retrieveByIDErr error changeStatusErr error checkSuperAdminErr error err error }{ { desc: "disable enabled user", id: enabledUser1.ID, user: enabledUser1, retrieveByIDResponse: enabledUser1, changeStatusResponse: disenabledUser1, response: disenabledUser1, err: nil, }, { desc: "disable enabled user with normal user token", id: enabledUser1.ID, user: enabledUser1, checkSuperAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, { desc: "disable enabled user with failed to retrieve user by ID", id: enabledUser1.ID, user: enabledUser1, retrieveByIDResponse: users.User{}, retrieveByIDErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "disable already disabled user", id: disabledUser1.ID, user: disabledUser1, retrieveByIDResponse: disabledUser1, err: svcerr.ErrStatusAlreadyAssigned, }, { desc: "disable enabled user 
with failed to change status", id: enabledUser1.ID, user: enabledUser1, changeStatusResponse: users.User{}, changeStatusErr: repoerr.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr) repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr) _, err := svc.Disable(context.Background(), authn.Session{}, tc.id) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if tc.err == nil { ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() repoCall2.Unset() }) } } func TestDeleteUser(t *testing.T) { svc, cRepo := newServiceMinimal() enabledUser1 := users.User{ID: testsutil.GenerateUUID(t), Credentials: users.Credentials{Username: "user1@example.com", Secret: "password"}, Status: users.EnabledStatus} deletedUser1 := users.User{ID: testsutil.GenerateUUID(t), Credentials: users.Credentials{Username: "user3@example.com", Secret: "password"}, Status: users.DeletedStatus} disenabledUser1 := enabledUser1 disenabledUser1.Status = users.DeletedStatus cases := []struct { desc string id string session authn.Session user users.User retrieveByIDResponse users.User changeStatusResponse users.User response users.User retrieveByIDErr error changeStatusErr error checkSuperAdminErr error err error }{ { desc: "delete enabled user", id: enabledUser1.ID, user: enabledUser1, 
session: authn.Session{UserID: validID, SuperAdmin: true}, retrieveByIDResponse: enabledUser1, changeStatusResponse: disenabledUser1, response: disenabledUser1, err: nil, }, { desc: "delete enabled user with failed to retrieve user by ID", id: enabledUser1.ID, user: enabledUser1, session: authn.Session{UserID: validID, SuperAdmin: true}, retrieveByIDResponse: users.User{}, retrieveByIDErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "delete already deleted user", id: deletedUser1.ID, user: deletedUser1, session: authn.Session{UserID: validID, SuperAdmin: true}, retrieveByIDResponse: deletedUser1, err: svcerr.ErrStatusAlreadyAssigned, }, { desc: "delete enabled user with failed to change status", id: enabledUser1.ID, user: enabledUser1, session: authn.Session{UserID: validID, SuperAdmin: true}, retrieveByIDResponse: enabledUser1, changeStatusResponse: users.User{}, changeStatusErr: repoerr.ErrMalformedEntity, err: svcerr.ErrUpdateEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall2 := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) repoCall3 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr) repoCall4 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr) err := svc.Delete(context.Background(), tc.session, tc.id) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if tc.err == nil { ok := repoCall3.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) ok = repoCall4.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything) assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc)) } repoCall2.Unset() repoCall3.Unset() repoCall4.Unset() }) } } func TestIssueToken(t 
*testing.T) { svc, auth, cRepo, _, _ := newService() rUser := user rUser2 := user rUser3 := user rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret) rUser2.Credentials.Secret = "wrongsecret" rUser3.Credentials.Secret, _ = phasher.Hash("wrongsecret") cases := []struct { desc string user users.User retrieveByUsernameResponse users.User issueResponse *grpcTokenV1.Token retrieveByUsernameErr error issueErr error err error }{ { desc: "issue token for an existing user", user: user, retrieveByUsernameResponse: rUser, issueResponse: &grpcTokenV1.Token{AccessToken: validToken, RefreshToken: &validToken, AccessType: "3"}, err: nil, }, { desc: "issue token for non-empty domain id", user: user, retrieveByUsernameResponse: rUser, issueResponse: &grpcTokenV1.Token{AccessToken: validToken, RefreshToken: &validToken, AccessType: "3"}, err: nil, }, { desc: "issue token for a non-existing user", user: user, retrieveByUsernameResponse: users.User{}, retrieveByUsernameErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "issue token for a user with wrong secret", user: user, retrieveByUsernameResponse: rUser3, err: svcerr.ErrLogin, }, { desc: "issue token with empty domain id", user: user, retrieveByUsernameResponse: rUser, issueResponse: &grpcTokenV1.Token{}, issueErr: svcerr.ErrAuthentication, err: svcerr.ErrAuthentication, }, { desc: "issue token with grpc error", user: user, retrieveByUsernameResponse: rUser, issueResponse: &grpcTokenV1.Token{}, issueErr: svcerr.ErrAuthentication, err: svcerr.ErrAuthentication, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr) authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr) token, err := 
svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret, "") assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken())) assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken())) ok := repoCall.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.user.Credentials.Username) assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc)) ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}) assert.True(t, ok, fmt.Sprintf("Issue was not called on %s", tc.desc)) } authCall.Unset() repoCall.Unset() }) } } func TestRefreshToken(t *testing.T) { svc, authsvc, crepo, _, _ := newService() rUser := user rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret) cases := []struct { desc string session authn.Session refreshResp *grpcTokenV1.Token refresErr error repoResp users.User repoErr error err error }{ { desc: "refresh token with refresh token for an existing user", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, refreshResp: &grpcTokenV1.Token{AccessToken: validToken, RefreshToken: &validToken, AccessType: "3"}, repoResp: rUser, err: nil, }, { desc: "refresh token with refresh token for empty domain id", session: authn.Session{UserID: validID}, refreshResp: &grpcTokenV1.Token{AccessToken: validToken, RefreshToken: &validToken, AccessType: "3"}, repoResp: rUser, err: nil, }, { desc: "refresh token with access token for an existing user", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, refreshResp: &grpcTokenV1.Token{}, refresErr: 
svcerr.ErrAuthentication, repoResp: rUser, err: svcerr.ErrAuthentication, }, { desc: "refresh token with refresh token for a non-existing client", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, repoErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "refresh token with refresh token for a disable user", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, repoResp: users.User{Status: users.DisabledStatus}, err: svcerr.ErrAuthentication, }, { desc: "refresh token with empty domain id", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, refreshResp: &grpcTokenV1.Token{}, refresErr: svcerr.ErrAuthentication, repoResp: rUser, err: svcerr.ErrAuthentication, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { authCall := authsvc.On("Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken}).Return(tc.refreshResp, tc.refresErr) repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr) token, err := svc.RefreshToken(context.Background(), tc.session, validToken) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken())) assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken())) ok := authCall.Parent.AssertCalled(t, "Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken}) assert.True(t, ok, fmt.Sprintf("Refresh was not called on %s", tc.desc)) ok = repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) } authCall.Unset() repoCall.Unset() }) } } func TestRevokeRefreshToken(t *testing.T) 
{ svc, authsvc, crepo, _, _ := newService() rUser := user rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret) cases := []struct { desc string session authn.Session tokenID string revokeResp *grpcTokenV1.RevokeRes revokeErr error repoResp users.User repoErr error err error }{ { desc: "revoke refresh token successfully", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tokenID: validToken, revokeResp: &grpcTokenV1.RevokeRes{}, repoResp: rUser, err: nil, }, { desc: "revoke refresh token with empty domain id", session: authn.Session{UserID: validID}, tokenID: validToken, revokeResp: &grpcTokenV1.RevokeRes{}, repoResp: rUser, err: nil, }, { desc: "revoke refresh token for non-existing user", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tokenID: validToken, repoErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "revoke refresh token for disabled user", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tokenID: validToken, repoResp: users.User{Status: users.DisabledStatus}, err: svcerr.ErrAuthentication, }, { desc: "revoke refresh token with revoke service error", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tokenID: validToken, revokeResp: &grpcTokenV1.RevokeRes{}, revokeErr: svcerr.ErrAuthorization, repoResp: rUser, err: svcerr.ErrAuthorization, }, { desc: "revoke refresh token not found", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tokenID: validToken, revokeResp: &grpcTokenV1.RevokeRes{}, revokeErr: svcerr.ErrNotFound, repoResp: rUser, err: svcerr.ErrNotFound, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr) authCall := authsvc.On("Revoke", context.Background(), &grpcTokenV1.RevokeReq{UserId: tc.session.UserID, TokenId: 
tc.tokenID}).Return(tc.revokeResp, tc.revokeErr) err := svc.RevokeRefreshToken(context.Background(), tc.session, tc.tokenID) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) ok = authCall.Parent.AssertCalled(t, "Revoke", context.Background(), &grpcTokenV1.RevokeReq{UserId: tc.session.UserID, TokenId: tc.tokenID}) assert.True(t, ok, fmt.Sprintf("Revoke was not called on %s", tc.desc)) } repoCall.Unset() authCall.Unset() }) } } func TestListActiveRefreshTokens(t *testing.T) { svc, authsvc, crepo, _, _ := newService() rUser := user rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret) cases := []struct { desc string session authn.Session listResp *grpcTokenV1.ListUserRefreshTokensRes listErr error repoResp users.User repoErr error expectedTokens int err error }{ { desc: "list active refresh tokens successfully", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, listResp: &grpcTokenV1.ListUserRefreshTokensRes{ RefreshTokens: []*grpcTokenV1.RefreshToken{ {Id: "token1", Description: "token1"}, {Id: "token2", Description: "token2"}, }, }, repoResp: rUser, expectedTokens: 2, err: nil, }, { desc: "list active refresh tokens with empty domain id", session: authn.Session{UserID: validID}, listResp: &grpcTokenV1.ListUserRefreshTokensRes{ RefreshTokens: []*grpcTokenV1.RefreshToken{ {Id: "token1", Description: "token1"}, }, }, repoResp: rUser, expectedTokens: 1, err: nil, }, { desc: "list active refresh tokens for non-existing user", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, repoErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "list active refresh tokens for disabled user", session: authn.Session{DomainUserID: validID, UserID: validID, 
DomainID: validID}, repoResp: users.User{Status: users.DisabledStatus}, err: svcerr.ErrAuthentication, }, { desc: "list active refresh tokens with list service error", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, listResp: &grpcTokenV1.ListUserRefreshTokensRes{}, listErr: svcerr.ErrAuthentication, repoResp: rUser, err: svcerr.ErrAuthentication, }, { desc: "list active refresh tokens with empty list", session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, listResp: &grpcTokenV1.ListUserRefreshTokensRes{RefreshTokens: []*grpcTokenV1.RefreshToken{}}, repoResp: rUser, expectedTokens: 0, err: nil, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr) authCall := authsvc.On("ListUserRefreshTokens", context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.session.UserID}).Return(tc.listResp, tc.listErr) tokens, err := svc.ListActiveRefreshTokens(context.Background(), tc.session) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { assert.NotNil(t, tokens, fmt.Sprintf("%s: expected tokens not to be nil\n", tc.desc)) assert.Equal(t, tc.expectedTokens, len(tokens.GetRefreshTokens()), fmt.Sprintf("%s: expected %d tokens got %d\n", tc.desc, tc.expectedTokens, len(tokens.GetRefreshTokens()))) ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) ok = authCall.Parent.AssertCalled(t, "ListUserRefreshTokens", context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.session.UserID}) assert.True(t, ok, fmt.Sprintf("ListUserRefreshTokens was not called on %s", tc.desc)) } repoCall.Unset() authCall.Unset() }) } } func TestSendPasswordReset(t *testing.T) { svc, auth, cRepo, _, e := 
newService() cases := []struct { desc string email string retrieveByEmailResponse users.User issueResponse *grpcTokenV1.Token retrieveByEmailErr error issueErr error err error }{ { desc: "generate reset token for existing user", email: "existingemail@example.com", retrieveByEmailResponse: user, issueResponse: &grpcTokenV1.Token{AccessToken: validToken, RefreshToken: &validToken, AccessType: "3"}, err: nil, }, { desc: "generate reset token for user with non-existing user", email: "example@example.com", retrieveByEmailResponse: users.User{ ID: testsutil.GenerateUUID(t), Email: "", }, retrieveByEmailErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "generate reset token with failed to issue token", email: "existingemail@example.com", retrieveByEmailResponse: user, issueResponse: &grpcTokenV1.Token{}, issueErr: svcerr.ErrAuthorization, err: svcerr.ErrAuthorization, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("RetrieveByEmail", context.Background(), tc.email).Return(tc.retrieveByEmailResponse, tc.retrieveByEmailErr) authCall := auth.On("Issue", context.Background(), mock.Anything).Return(tc.issueResponse, tc.issueErr) svcCall := e.On("SendPasswordReset", []string{tc.email}, user.Credentials.Username, validToken).Return(tc.err) err := svc.SendPasswordReset(context.Background(), tc.email) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) repoCall.Parent.AssertCalled(t, "RetrieveByEmail", context.Background(), tc.email) repoCall.Unset() authCall.Unset() svcCall.Unset() }) } } func TestResetSecret(t *testing.T) { svc, cRepo := newServiceMinimal() user := users.User{ ID: "userID", Email: "test@example.com", Credentials: users.Credentials{ Secret: "Strongsecret", }, } cases := []struct { desc string newSecret string session authn.Session retrieveByIDResponse users.User updateSecretResponse users.User retrieveByIDErr error updateSecretErr error err error }{ { 
desc: "reset secret with successfully", newSecret: "newStrongSecret", session: authn.Session{UserID: validID, SuperAdmin: true}, retrieveByIDResponse: user, updateSecretResponse: users.User{ ID: "userID", Email: "test@example.com", Credentials: users.Credentials{ Secret: "newStrongSecret", }, }, err: nil, }, { desc: "reset secret with invalid ID", newSecret: "newStrongSecret", session: authn.Session{UserID: validID, SuperAdmin: true}, retrieveByIDResponse: users.User{}, retrieveByIDErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, { desc: "reset secret with empty email", session: authn.Session{UserID: validID, SuperAdmin: true}, newSecret: "newStrongSecret", retrieveByIDResponse: users.User{ ID: "userID", Email: "", }, err: nil, }, { desc: "reset secret with failed to update secret", newSecret: "newStrongSecret", session: authn.Session{UserID: validID, SuperAdmin: true}, retrieveByIDResponse: user, updateSecretResponse: users.User{}, updateSecretErr: svcerr.ErrUpdateEntity, err: svcerr.ErrAuthorization, }, { desc: "reset secret with a too long secret", newSecret: strings.Repeat("strongSecret", 10), session: authn.Session{UserID: validID, SuperAdmin: true}, retrieveByIDResponse: user, err: errHashPassword, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr) repoCall1 := cRepo.On("UpdateSecret", context.Background(), mock.Anything).Return(tc.updateSecretResponse, tc.updateSecretErr) err := svc.ResetSecret(context.Background(), tc.session, tc.newSecret) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if tc.err == nil { repoCall1.Parent.AssertCalled(t, "UpdateSecret", context.Background(), mock.Anything) repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), validID) } repoCall1.Unset() repoCall.Unset() }) } } func TestViewProfile(t *testing.T) { svc, 
cRepo := newServiceMinimal() user := users.User{ ID: "userID", Email: "existingEmail", Credentials: users.Credentials{ Secret: "Strongsecret", }, } cases := []struct { desc string user users.User session authn.Session retrieveByIDResponse users.User retrieveByIDErr error err error }{ { desc: "view profile successfully", user: user, session: authn.Session{UserID: validID}, retrieveByIDResponse: user, err: nil, }, { desc: "view profile with invalid ID", user: user, session: authn.Session{UserID: wrongID}, retrieveByIDResponse: users.User{}, retrieveByIDErr: repoerr.ErrNotFound, err: repoerr.ErrNotFound, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr) _, err := svc.ViewProfile(context.Background(), tc.session) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), mock.Anything) repoCall.Unset() }) } } func TestOAuthCallback(t *testing.T) { svc, _, cRepo, policies, _ := newService() cases := []struct { desc string user users.User retrieveByEmailResponse users.User retrieveByEmailErr error saveResponse users.User addPoliciesErr error updateVerifiedAtErr error err error }{ { desc: "oauth signin callback with already existing user", user: users.User{ Email: "test@example.com", }, retrieveByEmailResponse: users.User{ ID: testsutil.GenerateUUID(t), Role: users.UserRole, VerifiedAt: time.Now(), }, err: nil, }, { desc: "oauth signup callback with user not found", user: users.User{ Email: "test@example.com", }, retrieveByEmailErr: repoerr.ErrNotFound, saveResponse: users.User{ ID: testsutil.GenerateUUID(t), Role: users.UserRole, }, err: nil, }, { desc: "oauth signup callback with malformed entity", user: users.User{ Email: "test@example.com", }, retrieveByEmailErr: repoerr.ErrMalformedEntity, err: 
repoerr.ErrMalformedEntity,
		},
		{
			desc: "oauth signup callback with failed to register user",
			user: users.User{
				Email: "test@example.com",
			},
			addPoliciesErr:     svcerr.ErrAuthorization,
			retrieveByEmailErr: repoerr.ErrNotFound,
			err:                svcerr.ErrAuthorization,
		},
		{
			desc: "oauth signin callback with user not in the platform",
			user: users.User{
				Email: "test@example.com",
			},
			retrieveByEmailResponse: users.User{
				ID:   testsutil.GenerateUUID(t),
				Role: users.UserRole,
			},
			err: nil,
		},
		{
			desc: "oauth signin callback with failed update verified at",
			user: users.User{
				Email: "test@example.com",
			},
			retrieveByEmailResponse: users.User{
				ID:   testsutil.GenerateUUID(t),
				Role: users.UserRole,
			},
			updateVerifiedAtErr: svcerr.ErrUpdateEntity,
			err:                 svcerr.ErrUpdateEntity,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			repoCall := cRepo.On("RetrieveByEmail", context.Background(), tc.user.Email).Return(tc.retrieveByEmailResponse, tc.retrieveByEmailErr)
			repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.saveResponse, nil)
			repoCall2 := cRepo.On("UpdateVerifiedAt", context.Background(), mock.MatchedBy(func(u users.User) bool {
				assert.NotEmpty(t, u.ID, "UpdateVerifiedAt must be called with non-empty user ID")
				return u.ID != ""
			})).Return(tc.retrieveByEmailResponse, tc.updateVerifiedAtErr)
			policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr)
			_, err := svc.OAuthCallback(context.Background(), tc.user)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
			repoCall.Parent.AssertCalled(t, "RetrieveByEmail", context.Background(), tc.user.Email)
			repoCall.Unset()
			repoCall1.Unset()
			policyCall.Unset()
			_ = repoCall2
			cRepo.ExpectedCalls = nil
			policies.ExpectedCalls = nil
		})
	}
}

// TestSendVerification runs Service.SendVerification against a table of
// mocked repository and emailer outcomes and checks, via errors.Contains,
// that the returned error matches the expected one for each case.
func TestSendVerification(t *testing.T) {
	svc, _, cRepo, _, e := newService()
	verifiedAt := time.Now().UTC()

	cases := []struct {
		desc                       string
		session                    authn.Session
		retrieveByIDResponse       users.User
		retrieveByIDError          error
		retrieveUserVerResponse    users.UserVerification
		retrieveUserVerError       error
		addUserVerError            error
		sendVerificationEmailError error
		err                        error
	}{
		{
			desc:                       "send verification email successfully",
			session:                    authn.Session{UserID: user.ID},
			retrieveByIDResponse:       user,
			retrieveUserVerError:       repoerr.ErrNotFound,
			sendVerificationEmailError: nil,
			err:                        nil,
		},
		{
			desc:                 "send verification email for already verified user",
			session:              authn.Session{UserID: user.ID},
			retrieveByIDResponse: users.User{VerifiedAt: verifiedAt},
			err:                  svcerr.ErrUserAlreadyVerified,
		},
		{
			desc:              "send verification email for non-existing user",
			session:           authn.Session{UserID: wrongID},
			retrieveByIDError: repoerr.ErrNotFound,
			err:               repoerr.ErrNotFound,
		},
		{
			desc:                 "send verification email with failed to retrieve user verification",
			session:              authn.Session{UserID: user.ID},
			retrieveByIDResponse: user,
			retrieveUserVerError: svcerr.ErrViewEntity,
			err:                  svcerr.ErrViewEntity,
		},
		{
			desc:                 "send verification email with failed to add user verification",
			session:              authn.Session{UserID: user.ID},
			retrieveByIDResponse: user,
			retrieveUserVerError: repoerr.ErrNotFound,
			addUserVerError:      svcerr.ErrCreateEntity,
			err:                  svcerr.ErrCreateEntity,
		},
		{
			desc:                       "send verification email with failed to send email",
			session:                    authn.Session{UserID: user.ID},
			retrieveByIDResponse:       user,
			retrieveUserVerError:       repoerr.ErrNotFound,
			sendVerificationEmailError: svcerr.ErrCreateEntity,
			err:                        svcerr.ErrCreateEntity,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			repoCall := cRepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.retrieveByIDResponse, tc.retrieveByIDError)
			repoCall1 := cRepo.On("RetrieveUserVerification", context.Background(), mock.Anything, mock.Anything).Return(tc.retrieveUserVerResponse, tc.retrieveUserVerError)
			repoCall2 := cRepo.On("AddUserVerification", context.Background(), mock.Anything).Return(tc.addUserVerError)
			// The emailer expectation is keyed on the package-level fixture `user`.
			emailCall := e.On("SendVerification", []string{user.Email}, user.Credentials.Username, mock.Anything).Return(tc.sendVerificationEmailError)

			err := svc.SendVerification(context.Background(), tc.session)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))

			// Remove this case's expectations so they do not leak into the next subtest.
			repoCall.Unset()
			repoCall1.Unset()
			repoCall2.Unset()
			emailCall.Unset()
		})
	}
}

// TestVerifyEmail runs Service.VerifyEmail with an encoded user verification
// (or a malformed one) against a table of mocked repository outcomes and
// checks, via errors.Contains, that the returned error matches each case.
func TestVerifyEmail(t *testing.T) {
	//nolint:dogsled
	svc, _, cRepo, _, _ := newService()

	// Every table case depends on uv/uvs, so a setup failure must stop the
	// test immediately rather than run all cases against zero values.
	uv, err := users.NewUserVerification(user.ID, user.Email)
	if !assert.Nil(t, err, fmt.Sprintf("failed to create user verification: %v", err)) {
		t.FailNow()
	}
	uvs, err := uv.Encode()
	if !assert.Nil(t, err, fmt.Sprintf("failed to encode user verification: %v", err)) {
		t.FailNow()
	}

	// Backdated timestamps used to build the stored record for the
	// "expired token" case.
	createdAt := time.Now().Add(-5 * users.VerificationExpiryDuration).UTC()
	expiresAt := time.Now().Add(-users.VerificationExpiryDuration).UTC()

	cases := []struct {
		desc                    string
		uvs                     string
		retrieveUserVerResponse users.UserVerification
		retrieveUserVerError    error
		updateUserVerError      error
		updateVerifiedAtError   error
		err                     error
	}{
		{
			desc:                    "verify email successfully",
			uvs:                     uvs,
			retrieveUserVerResponse: uv,
			err:                     nil,
		},
		{
			desc: "verify email with malformed token",
			uvs:  "invalid",
			err:  svcerr.ErrInvalidUserVerification,
		},
		{
			desc:                 "verify email with non-existing user verification",
			uvs:                  uvs,
			retrieveUserVerError: repoerr.ErrNotFound,
			err:                  svcerr.ErrViewEntity,
		},
		{
			desc: "verify email with expired token",
			uvs:  uvs,
			retrieveUserVerResponse: users.UserVerification{
				UserID:    uv.UserID,
				Email:     uv.Email,
				OTP:       uv.OTP,
				ExpiresAt: expiresAt,
				CreatedAt: createdAt,
				UsedAt:    uv.UsedAt,
			},
			err: svcerr.ErrUserVerificationExpired,
		},
		{
			desc:                    "verify email with failed to update user verification",
			uvs:                     uvs,
			retrieveUserVerResponse: uv,
			updateUserVerError:      svcerr.ErrUpdateEntity,
			err:                     svcerr.ErrUpdateEntity,
		},
		{
			desc:                    "verify email with failed to update verified at",
			uvs:                     uvs,
			retrieveUserVerResponse: uv,
			updateVerifiedAtError:   svcerr.ErrUpdateEntity,
			err:                     svcerr.ErrUpdateEntity,
		},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			repoCall := cRepo.On("RetrieveUserVerification", context.Background(), mock.Anything, mock.Anything).Return(tc.retrieveUserVerResponse, tc.retrieveUserVerError)
			repoCall1 := cRepo.On("UpdateUserVerification", context.Background(), mock.Anything).Return(tc.updateUserVerError)
			repoCall2 := cRepo.On("UpdateVerifiedAt", context.Background(), mock.Anything).Return(users.User{}, tc.updateVerifiedAtError)

			_, err := svc.VerifyEmail(context.Background(), tc.uvs)
			assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))

			// Remove this case's expectations so they do not leak into the next subtest.
			repoCall.Unset()
			repoCall1.Unset()
			repoCall2.Unset()
		})
	}
}