From dc72811048df060f876b9c7a556ec75561801c66 Mon Sep 17 00:00:00 2001 From: Steve Munene Date: Thu, 16 Apr 2026 19:08:14 +0300 Subject: [PATCH] NOISSUE - Update superadmin check (#3394) Signed-off-by: nyagamunene --- channels/middleware/authorization.go | 6 +- domains/middleware/authorization.go | 12 +++- users/middleware/authorization.go | 67 ++++++++++++++++---- users/service_test.go | 92 ++++++++++++++-------------- 4 files changed, 117 insertions(+), 60 deletions(-) diff --git a/channels/middleware/authorization.go b/channels/middleware/authorization.go index 3a2f721d9..03860dfdc 100644 --- a/channels/middleware/authorization.go +++ b/channels/middleware/authorization.go @@ -122,8 +122,12 @@ func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session auth } func (am *authorizationMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return channels.ChannelsPage{}, err } return am.svc.ListChannels(ctx, session, pm) diff --git a/domains/middleware/authorization.go b/domains/middleware/authorization.go index 06ec7456d..98c4d5184 100644 --- a/domains/middleware/authorization.go +++ b/domains/middleware/authorization.go @@ -57,9 +57,13 @@ func (am *authorizationMiddleware) CreateDomain(ctx context.Context, session aut } func (am *authorizationMiddleware) RetrieveDomain(ctx context.Context, session authn.Session, id string, withRoles bool) (domains.Domain, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true return am.svc.RetrieveDomain(ctx, session, id, withRoles) + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return domains.Domain{}, err } if err := am.authorize(ctx, session, policies.DomainType, operations.OpRetrieveDomain, authz.PolicyReq{ @@ -134,8 +138,12 @@ func (am *authorizationMiddleware) FreezeDomain(ctx context.Context, session aut } func (am *authorizationMiddleware) ListDomains(ctx context.Context, session authn.Session, page domains.Page) (domains.DomainsPage, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return domains.DomainsPage{}, err } return am.svc.ListDomains(ctx, session, page) diff --git a/users/middleware/authorization.go b/users/middleware/authorization.go index 4eb5b353d..e68c3845f 100644 --- a/users/middleware/authorization.go +++ b/users/middleware/authorization.go @@ -10,6 +10,7 @@ import ( "github.com/absmach/magistrala/auth" "github.com/absmach/magistrala/pkg/authn" smqauthz "github.com/absmach/magistrala/pkg/authz" + "github.com/absmach/magistrala/pkg/errors" svcerr "github.com/absmach/magistrala/pkg/errors/service" "github.com/absmach/magistrala/pkg/policies" "github.com/absmach/magistrala/users" @@ -38,8 +39,12 @@ func (am *authorizationMiddleware) VerifyEmail(ctx context.Context, verification func (am *authorizationMiddleware) Register(ctx context.Context, session authn.Session, user users.User, selfRegister bool) (users.User, error) { if selfRegister { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.User{}, err } } @@ -47,8 +52,12 @@ func (am *authorizationMiddleware) Register(ctx context.Context, session authn.S } func (am *authorizationMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.User{}, err } return am.svc.View(ctx, session, id) @@ -59,8 +68,12 @@ func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session auth } func (am *authorizationMiddleware) ListUsers(ctx context.Context, session authn.Session, pm users.Page) (users.UsersPage, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.UsersPage{}, err } return am.svc.ListUsers(ctx, session, pm) @@ -71,40 +84,60 @@ func (am *authorizationMiddleware) SearchUsers(ctx context.Context, pm users.Pag } func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, id string, user users.UserReq) (users.User, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.User{}, err } return am.svc.Update(ctx, session, id, user) } func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn.Session, id string, user users.UserReq) (users.User, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.User{}, err } return am.svc.UpdateTags(ctx, session, id, user) } func (am *authorizationMiddleware) UpdateEmail(ctx context.Context, session authn.Session, id, email string) (users.User, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.User{}, err } return am.svc.UpdateEmail(ctx, session, id, email) } func (am *authorizationMiddleware) UpdateUsername(ctx context.Context, session authn.Session, id, username string) (users.User, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.User{}, err } return am.svc.UpdateUsername(ctx, session, id, username) } func (am *authorizationMiddleware) UpdateProfilePicture(ctx context.Context, session authn.Session, id string, usr users.UserReq) (users.User, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.User{}, err } return am.svc.UpdateProfilePicture(ctx, session, id, usr) @@ -135,24 +168,36 @@ func (am *authorizationMiddleware) UpdateRole(ctx context.Context, session authn } func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Session, id string) (users.User, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.User{}, err } return am.svc.Enable(ctx, session, id) } func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Session, id string) (users.User, error) { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return users.User{}, err } return am.svc.Disable(ctx, session, id) } func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Session, id string) error { - if err := am.checkSuperAdmin(ctx, session); err == nil { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return err } return am.svc.Delete(ctx, session, id) diff --git a/users/service_test.go b/users/service_test.go index 482920ef4..3871fd74c 100644 --- a/users/service_test.go +++ b/users/service_test.go @@ -249,52 +249,52 @@ func TestRegister(t *testing.T) { svc, _, cRepo, policies, _ = newService() - // cases2 := []struct { - // desc string - // user users.User - // session authn.Session - // addPoliciesResponseErr error - // deletePoliciesResponseErr error - // saveErr error - // checkSuperAdminErr error - // err error - // }{ - // { - // desc: "register new user successfully as admin", - // user: user, - // session: authn.Session{UserID: validID, SuperAdmin: true}, - // err: nil, - // }, - // { - // desc: "register a new user as admin with failed check on super admin", - // user: user, - // session: authn.Session{UserID: validID, SuperAdmin: false}, - // checkSuperAdminErr: svcerr.ErrAuthorization, - // err: svcerr.ErrAuthorization, - // }, - // } - // for _, tc := range cases2 { - // repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) - // policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr) - // policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr) - // repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr) - // expected, err := svc.Register(context.Background(), authn.Session{UserID: validID}, tc.user, false) - // assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - // if err == nil { - // tc.user.ID = expected.ID - // tc.user.CreatedAt = expected.CreatedAt - // tc.user.UpdatedAt = expected.UpdatedAt - // tc.user.Credentials.Secret = expected.Credentials.Secret - // tc.user.UpdatedBy = expected.UpdatedBy - // assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected)) - // ok := repoCall1.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything) - // assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) - // } - // repoCall1.Unset() - // policyCall.Unset() - // policyCall1.Unset() - // repoCall.Unset() - // } + cases2 := []struct { + desc string + user users.User + session authn.Session + addPoliciesResponseErr error + deletePoliciesResponseErr error + saveErr error + checkSuperAdminErr error + err error + }{ + { + desc: "register new user successfully as admin", + user: user, + session: authn.Session{UserID: validID, SuperAdmin: true}, + err: nil, + }, + { + desc: "register a new user as admin with failed check on super admin", + user: user, + session: authn.Session{UserID: validID, SuperAdmin: false}, + checkSuperAdminErr: svcerr.ErrAuthorization, + err: svcerr.ErrAuthorization, + }, + } + for _, tc := range cases2 { + repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) + policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr) + policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr) + repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr) + expected, err := svc.Register(context.Background(), authn.Session{UserID: validID}, tc.user, false) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + tc.user.ID = expected.ID + tc.user.CreatedAt = expected.CreatedAt + tc.user.UpdatedAt = expected.UpdatedAt + tc.user.Credentials.Secret = expected.Credentials.Secret + tc.user.UpdatedBy = expected.UpdatedBy + assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected)) + ok := repoCall1.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything) + assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) + } + repoCall1.Unset() + policyCall.Unset() + policyCall1.Unset() + repoCall.Unset() + } } func TestViewUser(t *testing.T) {