MG-2007 - Rollback Policies if Entity Creation Fails (#2255)

Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com>
Signed-off-by: Rodney Osodo <socials@rodneyosodo.com>
This commit is contained in:
b1ackd0t
2024-06-12 11:27:44 +03:00
committed by GitHub
parent 20f8c084c1
commit ce02e30587
11 changed files with 279 additions and 125 deletions
+1 -1
View File
@@ -42,7 +42,7 @@ func (repo domainRepo) Save(ctx context.Context, d auth.Domain) (ad auth.Domain,
dbd, err := toDBDomain(d)
if err != nil {
return auth.Domain{}, errors.Wrap(repoerr.ErrCreateEntity, repoerr.ErrRollbackTx)
return auth.Domain{}, errors.Wrap(repoerr.ErrCreateEntity, errors.ErrRollbackTx)
}
row, err := repo.db.NamedQueryContext(ctx, q, dbd)
+87 -36
View File
@@ -66,46 +66,23 @@ func (svc service) CreateGroup(ctx context.Context, token, kind string, g groups
}
}
g, err = svc.groups.Save(ctx, g)
if err := svc.addGroupPolicy(ctx, res.GetId(), res.GetDomainId(), g.ID, g.Parent, kind); err != nil {
return groups.Group{}, err
}
defer func() {
if err != nil {
if errRollback := svc.addGroupPolicyRollback(ctx, res.GetId(), res.GetDomainId(), g.ID, g.Parent, kind); errRollback != nil {
err = errors.Wrap(errors.Wrap(errors.ErrRollbackTx, errRollback), err)
}
}
}()
saved, err := svc.groups.Save(ctx, g)
if err != nil {
return groups.Group{}, errors.Wrap(svcerr.ErrCreateEntity, err)
}
// IMPROVEMENT NOTE: Add defer function , if return err is not nil, then delete group
policies := magistrala.AddPoliciesReq{}
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: res.GetDomainId(),
SubjectType: auth.UserType,
Subject: res.GetId(),
Relation: auth.AdministratorRelation,
ObjectKind: kind,
ObjectType: auth.GroupType,
Object: g.ID,
})
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: res.GetDomainId(),
SubjectType: auth.DomainType,
Subject: res.GetDomainId(),
Relation: auth.DomainRelation,
ObjectType: auth.GroupType,
Object: g.ID,
})
if g.Parent != "" {
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: res.GetDomainId(),
SubjectType: auth.GroupType,
Subject: g.Parent,
Relation: auth.ParentGroupRelation,
ObjectKind: kind,
ObjectType: auth.GroupType,
Object: g.ID,
})
}
if _, err := svc.auth.AddPolicies(ctx, &policies); err != nil {
return g, errors.Wrap(svcerr.ErrAddPolicies, err)
}
return g, nil
return saved, nil
}
func (svc service) ViewGroup(ctx context.Context, token, id string) (groups.Group, error) {
@@ -753,3 +730,77 @@ func (svc service) authorizeKind(ctx context.Context, domainID, subjectType, sub
}
return res.GetId(), nil
}
func (svc service) addGroupPolicy(ctx context.Context, userID, domainID, id, parentID, kind string) error {
policies := magistrala.AddPoliciesReq{}
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: domainID,
SubjectType: auth.UserType,
Subject: userID,
Relation: auth.AdministratorRelation,
ObjectKind: kind,
ObjectType: auth.GroupType,
Object: id,
})
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: domainID,
SubjectType: auth.DomainType,
Subject: domainID,
Relation: auth.DomainRelation,
ObjectType: auth.GroupType,
Object: id,
})
if parentID != "" {
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: domainID,
SubjectType: auth.GroupType,
Subject: parentID,
Relation: auth.ParentGroupRelation,
ObjectKind: kind,
ObjectType: auth.GroupType,
Object: id,
})
}
if _, err := svc.auth.AddPolicies(ctx, &policies); err != nil {
return errors.Wrap(svcerr.ErrAddPolicies, err)
}
return nil
}
func (svc service) addGroupPolicyRollback(ctx context.Context, userID, domainID, id, parentID, kind string) error {
policies := magistrala.DeletePoliciesReq{}
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
Domain: domainID,
SubjectType: auth.UserType,
Subject: userID,
Relation: auth.AdministratorRelation,
ObjectKind: kind,
ObjectType: auth.GroupType,
Object: id,
})
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
Domain: domainID,
SubjectType: auth.DomainType,
Subject: domainID,
Relation: auth.DomainRelation,
ObjectType: auth.GroupType,
Object: id,
})
if parentID != "" {
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
Domain: domainID,
SubjectType: auth.GroupType,
Subject: parentID,
Relation: auth.ParentGroupRelation,
ObjectKind: kind,
ObjectType: auth.GroupType,
Object: id,
})
}
if _, err := svc.auth.DeletePolicies(ctx, &policies); err != nil {
return errors.Wrap(svcerr.ErrDeletePolicies, err)
}
return nil
}
+52 -23
View File
@@ -52,21 +52,23 @@ func TestCreateGroup(t *testing.T) {
svc := groups.NewService(repo, idProvider, authsvc)
cases := []struct {
desc string
token string
kind string
group mggroups.Group
idResp *magistrala.IdentityRes
idErr error
authzResp *magistrala.AuthorizeRes
authzErr error
authzTknResp *magistrala.AuthorizeRes
authzTknErr error
repoResp mggroups.Group
repoErr error
addPolResp *magistrala.AddPoliciesRes
addPolErr error
err error
desc string
token string
kind string
group mggroups.Group
idResp *magistrala.IdentityRes
idErr error
authzResp *magistrala.AuthorizeRes
authzErr error
authzTknResp *magistrala.AuthorizeRes
authzTknErr error
repoResp mggroups.Group
repoErr error
addPolResp *magistrala.AddPoliciesRes
addPolErr error
deletePolResp *magistrala.DeletePoliciesRes
deletePolErr error
err error
}{
{
desc: "successfully",
@@ -256,12 +258,37 @@ func TestCreateGroup(t *testing.T) {
addPolErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "with failed to delete policies response",
token: token,
kind: auth.NewGroupKind,
group: mggroups.Group{
Name: namegen.Generate(),
Description: namegen.Generate(),
Status: clients.Status(groups.EnabledStatus),
Parent: testsutil.GenerateUUID(t),
},
idResp: &magistrala.IdentityRes{
Id: testsutil.GenerateUUID(t),
DomainId: testsutil.GenerateUUID(t),
},
authzResp: &magistrala.AuthorizeRes{
Authorized: true,
},
authzTknResp: &magistrala.AuthorizeRes{
Authorized: true,
},
repoErr: errors.ErrMalformedEntity,
addPolResp: &magistrala.AddPoliciesRes{Added: true},
deletePolErr: svcerr.ErrAuthorization,
err: errors.ErrMalformedEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repocall := authsvc.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.idResp, tc.idErr)
repocall1 := authsvc.On("Authorize", context.Background(), &magistrala.AuthorizeReq{
authCall := authsvc.On("Authorize", context.Background(), &magistrala.AuthorizeReq{
SubjectType: auth.UserType,
SubjectKind: auth.UsersKind,
Subject: tc.idResp.GetId(),
@@ -269,7 +296,7 @@ func TestCreateGroup(t *testing.T) {
Object: tc.idResp.GetDomainId(),
ObjectType: auth.DomainType,
}).Return(tc.authzResp, tc.authzErr)
repocall2 := authsvc.On("Authorize", context.Background(), &magistrala.AuthorizeReq{
authCall1 := authsvc.On("Authorize", context.Background(), &magistrala.AuthorizeReq{
SubjectType: auth.UserType,
SubjectKind: auth.TokenKind,
Subject: tc.token,
@@ -277,8 +304,9 @@ func TestCreateGroup(t *testing.T) {
Object: tc.group.Parent,
ObjectType: auth.GroupType,
}).Return(tc.authzTknResp, tc.authzTknErr)
repocall3 := repo.On("Save", context.Background(), mock.Anything).Return(tc.repoResp, tc.repoErr)
repocall4 := authsvc.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPolResp, tc.addPolErr)
repocall1 := repo.On("Save", context.Background(), mock.Anything).Return(tc.repoResp, tc.repoErr)
authCall2 := authsvc.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPolResp, tc.addPolErr)
authCall3 := authsvc.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePolResp, tc.deletePolErr)
got, err := svc.CreateGroup(context.Background(), tc.token, tc.kind, tc.group)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
if err == nil {
@@ -286,14 +314,15 @@ func TestCreateGroup(t *testing.T) {
assert.NotEmpty(t, got.CreatedAt)
assert.NotEmpty(t, got.Domain)
assert.WithinDuration(t, time.Now(), got.CreatedAt, 2*time.Second)
ok := repocall3.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
ok := repocall1.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
}
repocall.Unset()
authCall.Unset()
authCall1.Unset()
repocall1.Unset()
repocall2.Unset()
repocall3.Unset()
repocall4.Unset()
authCall2.Unset()
authCall3.Unset()
})
}
}
-3
View File
@@ -31,9 +31,6 @@ var (
// ErrFailedOpDB indicates a failure in a database operation.
ErrFailedOpDB = errors.New("operation on db element failed")
// ErrRollbackTx indicates failed to rollback transaction.
ErrRollbackTx = errors.New("failed to rollback transaction")
// ErrFailedToRetrieveAllGroups failed to retrieve groups.
ErrFailedToRetrieveAllGroups = errors.New("failed to retrieve all groups")
)
+3
View File
@@ -20,4 +20,7 @@ var (
// ErrStatusAlreadyAssigned indicated that the client or group has already been assigned the status.
ErrStatusAlreadyAssigned = errors.New("status already assigned")
// ErrRollbackTx indicates failed to rollback transaction.
ErrRollbackTx = errors.New("failed to rollback transaction")
)
+10 -8
View File
@@ -137,21 +137,23 @@ func TestCreateChannel(t *testing.T) {
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
repoCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall3 := grepo.On("Save", mock.Anything, mock.Anything).Return(convertChannel(sdk.Channel{}), tc.err)
authCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
authCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
authCall3 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(&magistrala.DeletePoliciesRes{Deleted: false}, nil)
repoCall := grepo.On("Save", mock.Anything, mock.Anything).Return(convertChannel(sdk.Channel{}), tc.err)
rChannel, err := mgsdk.CreateChannel(tc.channel, validToken)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
if err == nil {
assert.NotEmpty(t, rChannel, fmt.Sprintf("%s: expected not nil on client ID", tc.desc))
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
}
authCall.Unset()
authCall1.Unset()
authCall2.Unset()
authCall3.Unset()
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
}
}
+10 -8
View File
@@ -138,21 +138,23 @@ func TestCreateGroup(t *testing.T) {
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
repoCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall3 := grepo.On("Save", mock.Anything, mock.Anything).Return(convertGroup(sdk.Group{}), tc.err)
authCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
authCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
authCall3 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(&magistrala.DeletePoliciesRes{Deleted: false}, nil)
repoCall := grepo.On("Save", mock.Anything, mock.Anything).Return(convertGroup(sdk.Group{}), tc.err)
rGroup, err := mgsdk.CreateGroup(tc.group, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
if err == nil {
assert.NotEmpty(t, rGroup, fmt.Sprintf("%s: expected not nil on client ID", tc.desc))
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
}
authCall.Unset()
authCall1.Unset()
authCall2.Unset()
authCall3.Unset()
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
}
}
+20 -16
View File
@@ -183,9 +183,10 @@ func TestCreateThing(t *testing.T) {
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
repoCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
authCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
authCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
authCall3 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(&magistrala.DeletePoliciesRes{Deleted: false}, nil)
repoCall3 := cRepo.On("Save", mock.Anything, mock.Anything).Return(convertThings(tc.response), tc.repoErr)
rThing, err := mgsdk.CreateThing(tc.client, tc.token)
@@ -200,9 +201,10 @@ func TestCreateThing(t *testing.T) {
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
authCall.Unset()
authCall1.Unset()
authCall2.Unset()
authCall3.Unset()
repoCall3.Unset()
}
}
@@ -270,12 +272,13 @@ func TestCreateThings(t *testing.T) {
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
repoCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall3 := cRepo.On("Save", mock.Anything, mock.Anything).Return(convertThings(tc.response...), tc.err)
authCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
authCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
authCall3 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(&magistrala.DeletePoliciesRes{Deleted: false}, nil)
repoCall1 := cRepo.On("Save", mock.Anything, mock.Anything).Return(convertThings(tc.response...), tc.err)
if len(tc.things) > 0 {
repoCall3 = cRepo.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(convertThings(tc.response...), tc.err)
repoCall1 = cRepo.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(convertThings(tc.response...), tc.err)
}
rThing, err := mgsdk.CreateThings(tc.things, tc.token)
for i, t := range rThing {
@@ -290,17 +293,18 @@ func TestCreateThings(t *testing.T) {
if tc.err == nil {
switch len(tc.things) {
case 1:
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
case 2:
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything, mock.Anything)
ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
}
}
repoCall.Unset()
authCall.Unset()
authCall1.Unset()
authCall2.Unset()
authCall3.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
}
}
+69 -24
View File
@@ -93,35 +93,22 @@ func (svc service) CreateThings(ctx context.Context, token string, cls ...mgclie
clients = append(clients, c)
}
if err := svc.addThingPolicies(ctx, user.GetId(), user.GetDomainId(), clients); err != nil {
return []mgclients.Client{}, err
}
defer func() {
if err != nil {
if errRollback := svc.addThingPoliciesRollback(ctx, user.GetId(), user.GetDomainId(), clients); errRollback != nil {
err = errors.Wrap(errors.Wrap(errors.ErrRollbackTx, errRollback), err)
}
}
}()
saved, err := svc.clients.Save(ctx, clients...)
if err != nil {
return nil, errors.Wrap(svcerr.ErrCreateEntity, err)
}
policies := magistrala.AddPoliciesReq{}
for _, c := range saved {
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: user.GetDomainId(),
SubjectType: auth.UserType,
Subject: user.GetId(),
Relation: auth.AdministratorRelation,
ObjectKind: auth.NewThingKind,
ObjectType: auth.ThingType,
Object: c.ID,
})
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: user.GetDomainId(),
SubjectType: auth.DomainType,
Subject: user.GetDomainId(),
Relation: auth.DomainRelation,
ObjectType: auth.ThingType,
Object: c.ID,
})
}
if _, err := svc.auth.AddPolicies(ctx, &policies); err != nil {
return nil, errors.Wrap(svcerr.ErrCreateEntity, err)
}
return saved, nil
}
@@ -609,3 +596,61 @@ func (svc *service) authorize(ctx context.Context, domainID, subjType, subjKind,
return res.GetId(), nil
}
func (svc service) addThingPolicies(ctx context.Context, userID, domainID string, things []mgclients.Client) error {
policies := magistrala.AddPoliciesReq{}
for _, thing := range things {
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: domainID,
SubjectType: auth.UserType,
Subject: userID,
Relation: auth.AdministratorRelation,
ObjectKind: auth.NewThingKind,
ObjectType: auth.ThingType,
Object: thing.ID,
})
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
Domain: domainID,
SubjectType: auth.DomainType,
Subject: domainID,
Relation: auth.DomainRelation,
ObjectType: auth.ThingType,
Object: thing.ID,
})
}
if _, err := svc.auth.AddPolicies(ctx, &policies); err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
return nil
}
func (svc service) addThingPoliciesRollback(ctx context.Context, userID, domainID string, things []mgclients.Client) error {
policies := magistrala.DeletePoliciesReq{}
for _, thing := range things {
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
Domain: domainID,
SubjectType: auth.UserType,
Subject: userID,
Relation: auth.AdministratorRelation,
ObjectKind: auth.NewThingKind,
ObjectType: auth.ThingType,
Object: thing.ID,
})
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
Domain: domainID,
SubjectType: auth.DomainType,
Subject: domainID,
Relation: auth.DomainRelation,
ObjectType: auth.ThingType,
Object: thing.ID,
})
}
if _, err := svc.auth.DeletePolicies(ctx, &policies); err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
}
+26 -5
View File
@@ -65,9 +65,11 @@ func TestCreateThings(t *testing.T) {
token string
authResponse *magistrala.AuthorizeRes
addPolicyResponse *magistrala.AddPoliciesRes
deletePolicyRes *magistrala.DeletePoliciesRes
authorizeErr error
identifyErr error
addPolicyErr error
deletePolicyErr error
saveErr error
err error
}{
@@ -305,13 +307,31 @@ func TestCreateThings(t *testing.T) {
addPolicyErr: svcerr.ErrInvalidPolicy,
err: svcerr.ErrInvalidPolicy,
},
{
desc: "create a new thing with failed delete policy response",
thing: mgclients.Client{
Credentials: mgclients.Credentials{
Identity: "newclientwithfailedpolicy@example.com",
Secret: secret,
},
Status: mgclients.EnabledStatus,
},
token: validToken,
authResponse: &magistrala.AuthorizeRes{Authorized: true},
addPolicyResponse: &magistrala.AddPoliciesRes{Added: true},
saveErr: repoerr.ErrConflict,
deletePolicyRes: &magistrala.DeletePoliciesRes{Deleted: false},
deletePolicyErr: svcerr.ErrInvalidPolicy,
err: repoerr.ErrConflict,
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, tc.identifyErr)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(tc.authResponse, tc.authorizeErr)
repoCall2 := cRepo.On("Save", context.Background(), mock.Anything).Return([]mgclients.Client{tc.thing}, tc.saveErr)
repoCall3 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPolicyResponse, tc.addPolicyErr)
authcall := auth.On("Authorize", mock.Anything, mock.Anything).Return(tc.authResponse, tc.authorizeErr)
repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return([]mgclients.Client{tc.thing}, tc.saveErr)
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPolicyResponse, tc.addPolicyErr)
authCall2 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePolicyRes, tc.deletePolicyErr)
expected, err := svc.CreateThings(context.Background(), tc.token, tc.thing)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
@@ -324,9 +344,10 @@ func TestCreateThings(t *testing.T) {
assert.Equal(t, tc.thing, expected[0], fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.thing, expected[0]))
}
repoCall.Unset()
authcall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
authCall1.Unset()
authCall2.Unset()
}
}
+1 -1
View File
@@ -84,7 +84,7 @@ func (svc service) RegisterClient(ctx context.Context, token string, cli mgclien
defer func() {
if err != nil {
if errRollback := svc.addClientPolicyRollback(ctx, cli.ID, cli.Role); errRollback != nil {
err = errors.Wrap(errors.Wrap(repoerr.ErrRollbackTx, errRollback), err)
err = errors.Wrap(errors.Wrap(errors.ErrRollbackTx, errRollback), err)
}
}
}()