MG-2556 - Fix super admin functionality (#2558)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2024-11-29 11:17:43 +03:00
committed by Dusan Borovcanin
parent ba556e1e0d
commit 5b2c1bab8e
5 changed files with 460 additions and 175 deletions
+108 -92
View File
@@ -497,15 +497,17 @@ func (svc service) RetrieveDomain(ctx context.Context, token, id string) (Domain
if err != nil {
return Domain{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if err = svc.Authorize(ctx, policies.Policy{
Subject: EncodeDomainUserID(id, res.User),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
}); err != nil {
return Domain{ID: domain.ID, Name: domain.Name, Alias: domain.Alias}, nil
if err := svc.checkSuperAdmin(ctx, res.User); err != nil {
if err = svc.Authorize(ctx, policies.Policy{
Subject: EncodeDomainUserID(id, res.User),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
}); err != nil {
return Domain{ID: domain.ID, Name: domain.Name, Alias: domain.Alias}, nil
}
}
return domain, nil
}
@@ -515,21 +517,25 @@ func (svc service) RetrieveDomainPermissions(ctx context.Context, token, id stri
if err != nil {
return []string{}, err
}
domainUserSubject := EncodeDomainUserID(id, res.User)
if err := svc.Authorize(ctx, policies.Policy{
Subject: domainUserSubject,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
}); err != nil {
return []string{}, err
subject := res.User
if err := svc.checkSuperAdmin(ctx, res.User); err != nil {
domainUserSubject := EncodeDomainUserID(id, res.User)
if err := svc.Authorize(ctx, policies.Policy{
Subject: domainUserSubject,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
}); err != nil {
return []string{}, err
}
subject = domainUserSubject
}
lp, err := svc.policysvc.ListPermissions(ctx, policies.Policy{
SubjectType: policies.UserType,
Subject: domainUserSubject,
Subject: subject,
Object: id,
ObjectType: policies.DomainType,
}, []string{policies.AdminPermission, policies.EditPermission, policies.ViewPermission, policies.MembershipPermission, policies.CreatePermission})
@@ -544,15 +550,17 @@ func (svc service) UpdateDomain(ctx context.Context, token, id string, d DomainR
if err != nil {
return Domain{}, err
}
if err := svc.Authorize(ctx, policies.Policy{
Subject: EncodeDomainUserID(id, key.User),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.EditPermission,
}); err != nil {
return Domain{}, err
if err := svc.checkSuperAdmin(ctx, key.User); err != nil {
if err := svc.Authorize(ctx, policies.Policy{
Subject: EncodeDomainUserID(id, key.User),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.EditPermission,
}); err != nil {
return Domain{}, err
}
}
dom, err := svc.domains.Update(ctx, id, key.User, d)
@@ -567,15 +575,17 @@ func (svc service) ChangeDomainStatus(ctx context.Context, token, id string, d D
if err != nil {
return Domain{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.Authorize(ctx, policies.Policy{
Subject: EncodeDomainUserID(id, key.User),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
}); err != nil {
return Domain{}, err
if err := svc.checkSuperAdmin(ctx, key.User); err != nil {
if err := svc.Authorize(ctx, policies.Policy{
Subject: EncodeDomainUserID(id, key.User),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
}); err != nil {
return Domain{}, err
}
}
dom, err := svc.domains.Update(ctx, id, key.User, d)
@@ -591,13 +601,7 @@ func (svc service) ListDomains(ctx context.Context, token string, p Page) (Domai
return DomainsPage{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
p.SubjectID = key.User
if err := svc.Authorize(ctx, policies.Policy{
Subject: key.User,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
ObjectType: policies.PlatformType,
Object: policies.MagistralaObject,
}); err == nil {
if err := svc.checkSuperAdmin(ctx, key.User); err == nil {
p.SubjectID = ""
}
dp, err := svc.domains.ListDomains(ctx, p)
@@ -618,27 +622,29 @@ func (svc service) AssignUsers(ctx context.Context, token, id string, userIds []
return errors.Wrap(svcerr.ErrAuthentication, err)
}
domainUserID := EncodeDomainUserID(id, res.User)
if err := svc.Authorize(ctx, policies.Policy{
Subject: domainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
}); err != nil {
return err
}
if err := svc.checkSuperAdmin(ctx, res.User); err != nil {
domainUserID := EncodeDomainUserID(id, res.User)
if err := svc.Authorize(ctx, policies.Policy{
Subject: domainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
}); err != nil {
return err
}
if err := svc.Authorize(ctx, policies.Policy{
Subject: domainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: SwitchToPermission(relation),
}); err != nil {
return err
if err := svc.Authorize(ctx, policies.Policy{
Subject: domainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: SwitchToPermission(relation),
}); err != nil {
return err
}
}
for _, userID := range userIds {
@@ -662,27 +668,29 @@ func (svc service) UnassignUser(ctx context.Context, token, id, userID string) e
return errors.Wrap(svcerr.ErrAuthentication, err)
}
domainUserID := EncodeDomainUserID(id, res.User)
pr := policies.Policy{
Subject: domainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
}
if err := svc.Authorize(ctx, pr); err != nil {
return err
}
if err := svc.checkSuperAdmin(ctx, res.User); err != nil {
domainUserID := EncodeDomainUserID(id, res.User)
pr := policies.Policy{
Subject: domainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
}
if err := svc.Authorize(ctx, pr); err != nil {
return err
}
pr.Permission = policies.AdminPermission
if err := svc.Authorize(ctx, pr); err != nil {
pr.SubjectKind = policies.UsersKind
// User is not admin.
pr.Subject = userID
if err := svc.Authorize(ctx, pr); err == nil {
// Non admin attempts to remove admin.
return errors.Wrap(svcerr.ErrAuthorization, err)
pr.Permission = policies.AdminPermission
if err := svc.Authorize(ctx, pr); err != nil {
pr.SubjectKind = policies.UsersKind
// User is not admin.
pr.Subject = userID
if err := svc.Authorize(ctx, pr); err == nil {
// Non admin attempts to remove admin.
return errors.Wrap(svcerr.ErrAuthorization, err)
}
}
}
@@ -713,13 +721,7 @@ func (svc service) ListUserDomains(ctx context.Context, token, userID string, p
if err != nil {
return DomainsPage{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.Authorize(ctx, policies.Policy{
Subject: res.User,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}); err != nil {
if err := svc.checkSuperAdmin(ctx, res.User); err != nil {
return DomainsPage{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
if userID != "" && res.User != userID {
@@ -906,3 +908,17 @@ func (svc service) DeleteUserFromDomains(ctx context.Context, id string) (err er
return nil
}
func (svc service) checkSuperAdmin(ctx context.Context, userID string) error {
if err := svc.evaluator.CheckPolicy(ctx, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}); err != nil {
return svcerr.ErrAuthorization
}
return nil
}
+325 -58
View File
@@ -1396,15 +1396,23 @@ func TestRetrieveDomain(t *testing.T) {
domainID string
domainRepoErr error
domainRepoErr1 error
checkAdminErr error
checkPolicyErr error
err error
}{
{
desc: "retrieve domain successfully",
desc: "retrieve domain successfully as super admin",
token: accessToken,
domainID: validID,
err: nil,
},
{
desc: "retrieve domain successfully as domain admin",
token: accessToken,
domainID: validID,
checkAdminErr: svcerr.ErrAuthorization,
err: nil,
},
{
desc: "retrieve domain with invalid token",
token: inValidToken,
@@ -1438,13 +1446,44 @@ func TestRetrieveDomain(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveByID", mock.Anything, groupName).Return(auth.Domain{}, tc.domainRepoErr)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}).Return(tc.checkAdminErr)
policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: auth.EncodeDomainUserID(tc.domainID, userID),
SubjectType: policies.UserType,
Permission: policies.MembershipPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: auth.EncodeDomainUserID(tc.domainID, userID),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.MembershipPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
repoCall2 := drepo.On("RetrieveByID", mock.Anything, tc.domainID).Return(auth.Domain{}, tc.domainRepoErr1)
_, err := svc.RetrieveDomain(context.Background(), tc.token, tc.domainID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
policyCall.Unset()
policyCall1.Unset()
policyCall2.Unset()
policyCall3.Unset()
})
}
}
@@ -1458,15 +1497,23 @@ func TestRetrieveDomainPermissions(t *testing.T) {
domainID string
retreivePermissionsErr error
retreiveByIDErr error
checkAdminErr error
checkPolicyErr error
err error
}{
{
desc: "retrieve domain permissions successfully",
desc: "retrieve domain permissions successfully as platform admin",
token: accessToken,
domainID: validID,
err: nil,
},
{
desc: "retrieve domain permissions successfully as domain admin",
token: accessToken,
domainID: validID,
checkAdminErr: svcerr.ErrAuthorization,
err: nil,
},
{
desc: "retrieve domain permissions with invalid token",
token: inValidToken,
@@ -1474,11 +1521,11 @@ func TestRetrieveDomainPermissions(t *testing.T) {
err: svcerr.ErrAuthentication,
},
{
desc: "retrieve domain permissions with empty domainID",
token: accessToken,
domainID: "",
checkPolicyErr: svcerr.ErrAuthorization,
err: svcerr.ErrDomainAuthorization,
desc: "retrieve domain permissions with empty domainID",
token: accessToken,
domainID: "",
retreivePermissionsErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "retrieve domain permissions with failed to retrieve permissions",
@@ -1491,6 +1538,7 @@ func TestRetrieveDomainPermissions(t *testing.T) {
desc: "retrieve domain permissions with failed to retrieve by id",
token: accessToken,
domainID: validID,
checkAdminErr: svcerr.ErrAuthorization,
retreiveByIDErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
@@ -1500,12 +1548,43 @@ func TestRetrieveDomainPermissions(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
repoCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(policies.Permissions{}, tc.retreivePermissionsErr)
repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retreiveByIDErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}).Return(tc.checkAdminErr)
policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: auth.EncodeDomainUserID(tc.domainID, userID),
SubjectType: policies.UserType,
Permission: policies.MembershipPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: auth.EncodeDomainUserID(tc.domainID, userID),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.MembershipPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
_, err := svc.RetrieveDomainPermissions(context.Background(), tc.token, tc.domainID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
policyCall.Unset()
policyCall1.Unset()
policyCall2.Unset()
policyCall3.Unset()
})
}
}
@@ -1521,10 +1600,11 @@ func TestUpdateDomain(t *testing.T) {
checkPolicyErr error
retrieveByIDErr error
updateErr error
checkAdminErr error
err error
}{
{
desc: "update domain successfully",
desc: "update domain successfully as platform admin",
token: accessToken,
domainID: validID,
domReq: auth.DomainReq{
@@ -1533,6 +1613,17 @@ func TestUpdateDomain(t *testing.T) {
},
err: nil,
},
{
desc: "update domain successfully as domain admin",
token: accessToken,
domainID: validID,
domReq: auth.DomainReq{
Name: &valid,
Alias: &valid,
},
checkAdminErr: svcerr.ErrAuthorization,
err: nil,
},
{
desc: "update domain with invalid token",
token: inValidToken,
@@ -1552,6 +1643,7 @@ func TestUpdateDomain(t *testing.T) {
Alias: &valid,
},
checkPolicyErr: svcerr.ErrAuthorization,
checkAdminErr: svcerr.ErrAuthorization,
err: svcerr.ErrDomainAuthorization,
},
{
@@ -1563,6 +1655,7 @@ func TestUpdateDomain(t *testing.T) {
Alias: &valid,
},
retrieveByIDErr: repoerr.ErrNotFound,
checkAdminErr: svcerr.ErrAuthorization,
err: svcerr.ErrNotFound,
},
{
@@ -1580,14 +1673,45 @@ func TestUpdateDomain(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}).Return(tc.checkAdminErr)
policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: auth.EncodeDomainUserID(tc.domainID, userID),
SubjectType: policies.UserType,
Permission: policies.MembershipPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: auth.EncodeDomainUserID(tc.domainID, userID),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.EditPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retrieveByIDErr)
repoCall2 := drepo.On("Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(auth.Domain{}, tc.updateErr)
_, err := svc.UpdateDomain(context.Background(), tc.token, tc.domainID, tc.domReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
policyCall.Unset()
policyCall1.Unset()
policyCall2.Unset()
policyCall3.Unset()
})
}
}
@@ -1604,11 +1728,12 @@ func TestChangeDomainStatus(t *testing.T) {
domainReq auth.DomainReq
retreieveByIDErr error
checkPolicyErr error
checkAdminErr error
updateErr error
err error
}{
{
desc: "change domain status successfully",
desc: "change domain status successfully as platform admin",
token: accessToken,
domainID: validID,
domainReq: auth.DomainReq{
@@ -1616,6 +1741,16 @@ func TestChangeDomainStatus(t *testing.T) {
},
err: nil,
},
{
desc: "change domain status successfully as platform admin",
token: accessToken,
domainID: validID,
domainReq: auth.DomainReq{
Status: &disabledStatus,
},
checkAdminErr: svcerr.ErrAuthorization,
err: nil,
},
{
desc: "change domain status with invalid token",
token: inValidToken,
@@ -1632,6 +1767,7 @@ func TestChangeDomainStatus(t *testing.T) {
domainReq: auth.DomainReq{
Status: &disabledStatus,
},
checkAdminErr: svcerr.ErrAuthorization,
retreieveByIDErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
@@ -1642,6 +1778,7 @@ func TestChangeDomainStatus(t *testing.T) {
domainReq: auth.DomainReq{
Status: &disabledStatus,
},
checkAdminErr: svcerr.ErrAuthorization,
checkPolicyErr: svcerr.ErrAuthorization,
err: svcerr.ErrDomainAuthorization,
},
@@ -1660,13 +1797,44 @@ func TestChangeDomainStatus(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retreieveByIDErr)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}).Return(tc.checkAdminErr)
policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: auth.EncodeDomainUserID(tc.domainID, userID),
SubjectType: policies.UserType,
Permission: policies.MembershipPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: auth.EncodeDomainUserID(tc.domainID, userID),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.AdminPermission,
Object: tc.domainID,
ObjectType: policies.DomainType,
}).Return(tc.checkPolicyErr)
repoCall2 := drepo.On("Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(auth.Domain{}, tc.updateErr)
_, err := svc.ChangeDomainStatus(context.Background(), tc.token, tc.domainID, tc.domainReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
policyCall.Unset()
policyCall1.Unset()
policyCall2.Unset()
policyCall3.Unset()
})
}
}
@@ -1743,25 +1911,26 @@ func TestAssignUsers(t *testing.T) {
svc, accessToken := newService()
cases := []struct {
desc string
token string
domainID string
userIDs []string
relation string
checkPolicyReq policies.Policy
checkAdminPolicyReq policies.Policy
checkDomainPolicyReq policies.Policy
checkPolicyReq1 policies.Policy
checkpolicyErr error
checkPolicyErr1 error
checkPolicyErr2 error
addPoliciesErr error
savePoliciesErr error
deletePoliciesErr error
err error
desc string
token string
domainID string
userIDs []string
relation string
checkPolicyReq policies.Policy
checkAdminPolicyReq policies.Policy
checkDomainPolicyReq policies.Policy
checkPolicyReq1 policies.Policy
checkpolicyErr error
checkPolicyErr1 error
checkPolicyErr2 error
addPoliciesErr error
savePoliciesErr error
deletePoliciesErr error
checkPlatformAdminErr error
err error
}{
{
desc: "assign users successfully",
desc: "assign users successfully as platform admin",
token: accessToken,
domainID: validID,
userIDs: []string{validID},
@@ -1798,6 +1967,45 @@ func TestAssignUsers(t *testing.T) {
},
err: nil,
},
{
desc: "assign users successfully",
token: accessToken,
domainID: validID,
userIDs: []string{validID},
relation: policies.ContributorRelation,
checkPolicyReq: policies.Policy{
Subject: auth.EncodeDomainUserID(validID, userID),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Subject: auth.EncodeDomainUserID(validID, userID),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.ViewPermission,
},
checkDomainPolicyReq: policies.Policy{
Subject: validID,
SubjectType: policies.UserType,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
Permission: policies.MembershipPermission,
},
checkPolicyReq1: policies.Policy{
Subject: auth.EncodeDomainUserID(validID, userID),
SubjectType: policies.UserType,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: nil,
},
{
desc: "assign users with invalid token",
token: inValidToken,
@@ -1828,7 +2036,8 @@ func TestAssignUsers(t *testing.T) {
ObjectType: policies.PlatformType,
Permission: policies.MembershipPermission,
},
err: svcerr.ErrAuthentication,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthentication,
},
{
desc: "assign users with invalid domainID",
@@ -1858,8 +2067,9 @@ func TestAssignUsers(t *testing.T) {
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
checkPolicyErr1: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
checkPolicyErr1: svcerr.ErrAuthorization,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "assign users with invalid userIDs",
@@ -1897,8 +2107,9 @@ func TestAssignUsers(t *testing.T) {
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
checkPolicyErr2: svcerr.ErrMalformedEntity,
err: svcerr.ErrDomainAuthorization,
checkPolicyErr2: svcerr.ErrMalformedEntity,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: svcerr.ErrDomainAuthorization,
},
{
desc: "assign users with failed to add policies to agent",
@@ -1936,8 +2147,9 @@ func TestAssignUsers(t *testing.T) {
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
addPoliciesErr: svcerr.ErrAuthorization,
err: errAddPolicies,
addPoliciesErr: svcerr.ErrAuthorization,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: errAddPolicies,
},
{
desc: "assign users with failed to save policies to domain",
@@ -1975,8 +2187,9 @@ func TestAssignUsers(t *testing.T) {
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
savePoliciesErr: repoerr.ErrCreateEntity,
err: errAddPolicies,
checkPlatformAdminErr: svcerr.ErrAuthorization,
savePoliciesErr: repoerr.ErrCreateEntity,
err: errAddPolicies,
},
{
desc: "assign users with failed to save policies to domain and failed to delete",
@@ -2014,15 +2227,23 @@ func TestAssignUsers(t *testing.T) {
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
savePoliciesErr: repoerr.ErrCreateEntity,
deletePoliciesErr: svcerr.ErrDomainAuthorization,
err: errAddPolicies,
savePoliciesErr: repoerr.ErrCreateEntity,
deletePoliciesErr: svcerr.ErrDomainAuthorization,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: errAddPolicies,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, nil)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}).Return(tc.checkPlatformAdminErr)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq).Return(tc.checkpolicyErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr2)
@@ -2040,6 +2261,7 @@ func TestAssignUsers(t *testing.T) {
repoCall5.Unset()
repoCall6.Unset()
repoCall7.Unset()
policyCall.Unset()
})
}
}
@@ -2059,10 +2281,11 @@ func TestUnassignUser(t *testing.T) {
checkPolicyErr1 error
deletePolicyFilterErr error
deletePoliciesErr error
checkPlatformAdminErr error
err error
}{
{
desc: "unassign user successfully",
desc: "unassign user successfully as platform admin",
token: accessToken,
domainID: validID,
userID: validID,
@@ -2091,6 +2314,37 @@ func TestUnassignUser(t *testing.T) {
},
err: nil,
},
{
desc: "unassign user successfully as domain admin",
token: accessToken,
domainID: validID,
userID: validID,
checkPolicyReq: policies.Policy{
Subject: auth.EncodeDomainUserID(validID, userID),
SubjectType: policies.UserType,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
checkAdminPolicyReq: policies.Policy{
Subject: auth.EncodeDomainUserID(validID, userID),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
},
checkDomainPolicyReq: policies.Policy{
Subject: auth.EncodeDomainUserID(validID, userID),
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: nil,
},
{
desc: "unassign users with invalid token",
token: inValidToken,
@@ -2112,7 +2366,8 @@ func TestUnassignUser(t *testing.T) {
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
},
err: svcerr.ErrAuthentication,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthentication,
},
{
desc: "unassign users with invalid domainID",
@@ -2142,8 +2397,9 @@ func TestUnassignUser(t *testing.T) {
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
checkPolicyErr1: svcerr.ErrAuthorization,
err: svcerr.ErrDomainAuthorization,
checkPolicyErr1: svcerr.ErrAuthorization,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: svcerr.ErrDomainAuthorization,
},
{
desc: "unassign users with failed to delete policies from agent",
@@ -2174,6 +2430,7 @@ func TestUnassignUser(t *testing.T) {
Permission: policies.MembershipPermission,
},
deletePolicyFilterErr: errors.ErrMalformedEntity,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: errors.ErrMalformedEntity,
},
{
@@ -2206,6 +2463,7 @@ func TestUnassignUser(t *testing.T) {
},
deletePoliciesErr: errors.ErrMalformedEntity,
deletePolicyFilterErr: errors.ErrMalformedEntity,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: errors.ErrMalformedEntity,
},
{
@@ -2236,26 +2494,35 @@ func TestUnassignUser(t *testing.T) {
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
deletePoliciesErr: errors.ErrMalformedEntity,
err: errors.ErrMalformedEntity,
deletePoliciesErr: errors.ErrMalformedEntity,
checkPlatformAdminErr: svcerr.ErrAuthorization,
err: errors.ErrMalformedEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, nil)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq).Return(tc.checkPolicyErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr1)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: userID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}).Return(tc.checkPlatformAdminErr)
policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq).Return(tc.checkPolicyErr)
policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1)
policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr1)
repoCall4 := pService.On("DeletePolicyFilter", mock.Anything, mock.Anything).Return(tc.deletePolicyFilterErr)
repoCall5 := drepo.On("DeletePolicies", mock.Anything, mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
err := svc.UnassignUser(context.Background(), tc.token, tc.domainID, tc.userID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
policyCall.Unset()
policyCall1.Unset()
policyCall2.Unset()
repoCall4.Unset()
policyCall3.Unset()
repoCall5.Unset()
})
}
+7 -10
View File
@@ -30,9 +30,6 @@ func AuthorizationMiddleware(authz authz.Authorization, svc invitations.Service)
}
func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) {
if err := am.checkAdmin(ctx, session.UserID, session.DomainID); err != nil {
return err
}
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
domainUserId := auth.EncodeDomainUserID(invitation.DomainID, invitation.UserID)
if err := am.authorize(ctx, domainUserId, policies.MembershipPermission, policies.DomainType, invitation.DomainID); err == nil {
@@ -40,7 +37,7 @@ func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session a
return errors.Wrap(svcerr.ErrConflict, ErrMemberExist)
}
if err := am.checkAdmin(ctx, session.DomainUserID, invitation.DomainID); err != nil {
if err := am.checkAdmin(ctx, session); err != nil {
return err
}
@@ -50,7 +47,7 @@ func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session a
func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session authn.Session, userID, domain string) (invitation invitations.Invitation, err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if session.UserID != userID {
if err := am.checkAdmin(ctx, session.DomainUserID, domain); err != nil {
if err := am.checkAdmin(ctx, session); err != nil {
return invitations.Invitation{}, err
}
}
@@ -60,7 +57,7 @@ func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session a
func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if err := am.authorize(ctx, session.DomainUserID, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject); err == nil {
if err := am.authorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject); err == nil {
session.SuperAdmin = true
}
@@ -88,7 +85,7 @@ func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session
func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if err := am.checkAdmin(ctx, session.DomainUserID, domainID); err != nil {
if err := am.checkAdmin(ctx, session); err != nil {
return err
}
@@ -96,12 +93,12 @@ func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session
}
// checkAdmin checks if the given user is a domain or platform administrator.
func (am *authorizationMiddleware) checkAdmin(ctx context.Context, userID, domainID string) error {
if err := am.authorize(ctx, userID, policies.AdminPermission, policies.DomainType, domainID); err == nil {
func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn.Session) error {
if err := am.authorize(ctx, session.DomainUserID, policies.AdminPermission, policies.DomainType, session.DomainID); err == nil {
return nil
}
if err := am.authorize(ctx, userID, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject); err == nil {
if err := am.authorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject); err == nil {
return nil
}
+4 -1
View File
@@ -49,7 +49,10 @@ func (svc *service) SendInvitation(ctx context.Context, session authn.Session, i
invitation.CreatedAt = time.Now()
return svc.repo.Create(ctx, invitation)
if err := svc.repo.Create(ctx, invitation); err != nil {
return err
}
return nil
}
func (svc *service) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation Invitation, err error) {
+16 -14
View File
@@ -67,21 +67,23 @@ func (am *authorizationMiddleware) ListMembers(ctx context.Context, session auth
if session.DomainUserID == "" {
return users.MembersPage{}, svcerr.ErrDomainAuthorization
}
switch objectKind {
case policies.GroupsKind:
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.GroupType, objectID); err != nil {
return users.MembersPage{}, err
if err := am.checkSuperAdmin(ctx, session.UserID); err != nil {
switch objectKind {
case policies.GroupsKind:
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.GroupType, objectID); err != nil {
return users.MembersPage{}, err
}
case policies.DomainsKind:
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.DomainType, objectID); err != nil {
return users.MembersPage{}, err
}
case policies.ThingsKind:
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.ThingType, objectID); err != nil {
return users.MembersPage{}, err
}
default:
return users.MembersPage{}, svcerr.ErrAuthorization
}
case policies.DomainsKind:
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.DomainType, objectID); err != nil {
return users.MembersPage{}, err
}
case policies.ThingsKind:
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.ThingType, objectID); err != nil {
return users.MembersPage{}, err
}
default:
return users.MembersPage{}, svcerr.ErrAuthorization
}
return am.svc.ListMembers(ctx, session, objectKind, objectID, pm)