mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 06:40:19 +00:00
MG-2007 - Rollback Policies if Entity Creation Fails (#2255)
Signed-off-by: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com> Signed-off-by: Rodney Osodo <socials@rodneyosodo.com>
This commit is contained in:
@@ -42,7 +42,7 @@ func (repo domainRepo) Save(ctx context.Context, d auth.Domain) (ad auth.Domain,
|
||||
|
||||
dbd, err := toDBDomain(d)
|
||||
if err != nil {
|
||||
return auth.Domain{}, errors.Wrap(repoerr.ErrCreateEntity, repoerr.ErrRollbackTx)
|
||||
return auth.Domain{}, errors.Wrap(repoerr.ErrCreateEntity, errors.ErrRollbackTx)
|
||||
}
|
||||
|
||||
row, err := repo.db.NamedQueryContext(ctx, q, dbd)
|
||||
|
||||
+87
-36
@@ -66,46 +66,23 @@ func (svc service) CreateGroup(ctx context.Context, token, kind string, g groups
|
||||
}
|
||||
}
|
||||
|
||||
g, err = svc.groups.Save(ctx, g)
|
||||
if err := svc.addGroupPolicy(ctx, res.GetId(), res.GetDomainId(), g.ID, g.Parent, kind); err != nil {
|
||||
return groups.Group{}, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if errRollback := svc.addGroupPolicyRollback(ctx, res.GetId(), res.GetDomainId(), g.ID, g.Parent, kind); errRollback != nil {
|
||||
err = errors.Wrap(errors.Wrap(errors.ErrRollbackTx, errRollback), err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
saved, err := svc.groups.Save(ctx, g)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
// IMPROVEMENT NOTE: Add defer function , if return err is not nil, then delete group
|
||||
|
||||
policies := magistrala.AddPoliciesReq{}
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: res.GetDomainId(),
|
||||
SubjectType: auth.UserType,
|
||||
Subject: res.GetId(),
|
||||
Relation: auth.AdministratorRelation,
|
||||
ObjectKind: kind,
|
||||
ObjectType: auth.GroupType,
|
||||
Object: g.ID,
|
||||
})
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: res.GetDomainId(),
|
||||
SubjectType: auth.DomainType,
|
||||
Subject: res.GetDomainId(),
|
||||
Relation: auth.DomainRelation,
|
||||
ObjectType: auth.GroupType,
|
||||
Object: g.ID,
|
||||
})
|
||||
if g.Parent != "" {
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: res.GetDomainId(),
|
||||
SubjectType: auth.GroupType,
|
||||
Subject: g.Parent,
|
||||
Relation: auth.ParentGroupRelation,
|
||||
ObjectKind: kind,
|
||||
ObjectType: auth.GroupType,
|
||||
Object: g.ID,
|
||||
})
|
||||
}
|
||||
if _, err := svc.auth.AddPolicies(ctx, &policies); err != nil {
|
||||
return g, errors.Wrap(svcerr.ErrAddPolicies, err)
|
||||
}
|
||||
|
||||
return g, nil
|
||||
return saved, nil
|
||||
}
|
||||
|
||||
func (svc service) ViewGroup(ctx context.Context, token, id string) (groups.Group, error) {
|
||||
@@ -753,3 +730,77 @@ func (svc service) authorizeKind(ctx context.Context, domainID, subjectType, sub
|
||||
}
|
||||
return res.GetId(), nil
|
||||
}
|
||||
|
||||
func (svc service) addGroupPolicy(ctx context.Context, userID, domainID, id, parentID, kind string) error {
|
||||
policies := magistrala.AddPoliciesReq{}
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.UserType,
|
||||
Subject: userID,
|
||||
Relation: auth.AdministratorRelation,
|
||||
ObjectKind: kind,
|
||||
ObjectType: auth.GroupType,
|
||||
Object: id,
|
||||
})
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.DomainType,
|
||||
Subject: domainID,
|
||||
Relation: auth.DomainRelation,
|
||||
ObjectType: auth.GroupType,
|
||||
Object: id,
|
||||
})
|
||||
if parentID != "" {
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.GroupType,
|
||||
Subject: parentID,
|
||||
Relation: auth.ParentGroupRelation,
|
||||
ObjectKind: kind,
|
||||
ObjectType: auth.GroupType,
|
||||
Object: id,
|
||||
})
|
||||
}
|
||||
if _, err := svc.auth.AddPolicies(ctx, &policies); err != nil {
|
||||
return errors.Wrap(svcerr.ErrAddPolicies, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) addGroupPolicyRollback(ctx context.Context, userID, domainID, id, parentID, kind string) error {
|
||||
policies := magistrala.DeletePoliciesReq{}
|
||||
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.UserType,
|
||||
Subject: userID,
|
||||
Relation: auth.AdministratorRelation,
|
||||
ObjectKind: kind,
|
||||
ObjectType: auth.GroupType,
|
||||
Object: id,
|
||||
})
|
||||
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.DomainType,
|
||||
Subject: domainID,
|
||||
Relation: auth.DomainRelation,
|
||||
ObjectType: auth.GroupType,
|
||||
Object: id,
|
||||
})
|
||||
if parentID != "" {
|
||||
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.GroupType,
|
||||
Subject: parentID,
|
||||
Relation: auth.ParentGroupRelation,
|
||||
ObjectKind: kind,
|
||||
ObjectType: auth.GroupType,
|
||||
Object: id,
|
||||
})
|
||||
}
|
||||
if _, err := svc.auth.DeletePolicies(ctx, &policies); err != nil {
|
||||
return errors.Wrap(svcerr.ErrDeletePolicies, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -52,21 +52,23 @@ func TestCreateGroup(t *testing.T) {
|
||||
svc := groups.NewService(repo, idProvider, authsvc)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
kind string
|
||||
group mggroups.Group
|
||||
idResp *magistrala.IdentityRes
|
||||
idErr error
|
||||
authzResp *magistrala.AuthorizeRes
|
||||
authzErr error
|
||||
authzTknResp *magistrala.AuthorizeRes
|
||||
authzTknErr error
|
||||
repoResp mggroups.Group
|
||||
repoErr error
|
||||
addPolResp *magistrala.AddPoliciesRes
|
||||
addPolErr error
|
||||
err error
|
||||
desc string
|
||||
token string
|
||||
kind string
|
||||
group mggroups.Group
|
||||
idResp *magistrala.IdentityRes
|
||||
idErr error
|
||||
authzResp *magistrala.AuthorizeRes
|
||||
authzErr error
|
||||
authzTknResp *magistrala.AuthorizeRes
|
||||
authzTknErr error
|
||||
repoResp mggroups.Group
|
||||
repoErr error
|
||||
addPolResp *magistrala.AddPoliciesRes
|
||||
addPolErr error
|
||||
deletePolResp *magistrala.DeletePoliciesRes
|
||||
deletePolErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "successfully",
|
||||
@@ -256,12 +258,37 @@ func TestCreateGroup(t *testing.T) {
|
||||
addPolErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "with failed to delete policies response",
|
||||
token: token,
|
||||
kind: auth.NewGroupKind,
|
||||
group: mggroups.Group{
|
||||
Name: namegen.Generate(),
|
||||
Description: namegen.Generate(),
|
||||
Status: clients.Status(groups.EnabledStatus),
|
||||
Parent: testsutil.GenerateUUID(t),
|
||||
},
|
||||
idResp: &magistrala.IdentityRes{
|
||||
Id: testsutil.GenerateUUID(t),
|
||||
DomainId: testsutil.GenerateUUID(t),
|
||||
},
|
||||
authzResp: &magistrala.AuthorizeRes{
|
||||
Authorized: true,
|
||||
},
|
||||
authzTknResp: &magistrala.AuthorizeRes{
|
||||
Authorized: true,
|
||||
},
|
||||
repoErr: errors.ErrMalformedEntity,
|
||||
addPolResp: &magistrala.AddPoliciesRes{Added: true},
|
||||
deletePolErr: svcerr.ErrAuthorization,
|
||||
err: errors.ErrMalformedEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repocall := authsvc.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.idResp, tc.idErr)
|
||||
repocall1 := authsvc.On("Authorize", context.Background(), &magistrala.AuthorizeReq{
|
||||
authCall := authsvc.On("Authorize", context.Background(), &magistrala.AuthorizeReq{
|
||||
SubjectType: auth.UserType,
|
||||
SubjectKind: auth.UsersKind,
|
||||
Subject: tc.idResp.GetId(),
|
||||
@@ -269,7 +296,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
Object: tc.idResp.GetDomainId(),
|
||||
ObjectType: auth.DomainType,
|
||||
}).Return(tc.authzResp, tc.authzErr)
|
||||
repocall2 := authsvc.On("Authorize", context.Background(), &magistrala.AuthorizeReq{
|
||||
authCall1 := authsvc.On("Authorize", context.Background(), &magistrala.AuthorizeReq{
|
||||
SubjectType: auth.UserType,
|
||||
SubjectKind: auth.TokenKind,
|
||||
Subject: tc.token,
|
||||
@@ -277,8 +304,9 @@ func TestCreateGroup(t *testing.T) {
|
||||
Object: tc.group.Parent,
|
||||
ObjectType: auth.GroupType,
|
||||
}).Return(tc.authzTknResp, tc.authzTknErr)
|
||||
repocall3 := repo.On("Save", context.Background(), mock.Anything).Return(tc.repoResp, tc.repoErr)
|
||||
repocall4 := authsvc.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPolResp, tc.addPolErr)
|
||||
repocall1 := repo.On("Save", context.Background(), mock.Anything).Return(tc.repoResp, tc.repoErr)
|
||||
authCall2 := authsvc.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPolResp, tc.addPolErr)
|
||||
authCall3 := authsvc.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePolResp, tc.deletePolErr)
|
||||
got, err := svc.CreateGroup(context.Background(), tc.token, tc.kind, tc.group)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
|
||||
if err == nil {
|
||||
@@ -286,14 +314,15 @@ func TestCreateGroup(t *testing.T) {
|
||||
assert.NotEmpty(t, got.CreatedAt)
|
||||
assert.NotEmpty(t, got.Domain)
|
||||
assert.WithinDuration(t, time.Now(), got.CreatedAt, 2*time.Second)
|
||||
ok := repocall3.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
|
||||
ok := repocall1.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
}
|
||||
repocall.Unset()
|
||||
authCall.Unset()
|
||||
authCall1.Unset()
|
||||
repocall1.Unset()
|
||||
repocall2.Unset()
|
||||
repocall3.Unset()
|
||||
repocall4.Unset()
|
||||
authCall2.Unset()
|
||||
authCall3.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,9 +31,6 @@ var (
|
||||
// ErrFailedOpDB indicates a failure in a database operation.
|
||||
ErrFailedOpDB = errors.New("operation on db element failed")
|
||||
|
||||
// ErrRollbackTx indicates failed to rollback transaction.
|
||||
ErrRollbackTx = errors.New("failed to rollback transaction")
|
||||
|
||||
// ErrFailedToRetrieveAllGroups failed to retrieve groups.
|
||||
ErrFailedToRetrieveAllGroups = errors.New("failed to retrieve all groups")
|
||||
)
|
||||
|
||||
@@ -20,4 +20,7 @@ var (
|
||||
|
||||
// ErrStatusAlreadyAssigned indicated that the client or group has already been assigned the status.
|
||||
ErrStatusAlreadyAssigned = errors.New("status already assigned")
|
||||
|
||||
// ErrRollbackTx indicates failed to rollback transaction.
|
||||
ErrRollbackTx = errors.New("failed to rollback transaction")
|
||||
)
|
||||
|
||||
@@ -137,21 +137,23 @@ func TestCreateChannel(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
|
||||
repoCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
|
||||
repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall3 := grepo.On("Save", mock.Anything, mock.Anything).Return(convertChannel(sdk.Channel{}), tc.err)
|
||||
authCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
|
||||
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
|
||||
authCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
authCall3 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(&magistrala.DeletePoliciesRes{Deleted: false}, nil)
|
||||
repoCall := grepo.On("Save", mock.Anything, mock.Anything).Return(convertChannel(sdk.Channel{}), tc.err)
|
||||
rChannel, err := mgsdk.CreateChannel(tc.channel, validToken)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, rChannel, fmt.Sprintf("%s: expected not nil on client ID", tc.desc))
|
||||
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
}
|
||||
authCall.Unset()
|
||||
authCall1.Unset()
|
||||
authCall2.Unset()
|
||||
authCall3.Unset()
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
repoCall3.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -138,21 +138,23 @@ func TestCreateGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
|
||||
repoCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
|
||||
repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall3 := grepo.On("Save", mock.Anything, mock.Anything).Return(convertGroup(sdk.Group{}), tc.err)
|
||||
authCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
|
||||
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
|
||||
authCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
authCall3 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(&magistrala.DeletePoliciesRes{Deleted: false}, nil)
|
||||
repoCall := grepo.On("Save", mock.Anything, mock.Anything).Return(convertGroup(sdk.Group{}), tc.err)
|
||||
rGroup, err := mgsdk.CreateGroup(tc.group, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, rGroup, fmt.Sprintf("%s: expected not nil on client ID", tc.desc))
|
||||
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
}
|
||||
authCall.Unset()
|
||||
authCall1.Unset()
|
||||
authCall2.Unset()
|
||||
authCall3.Unset()
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
repoCall3.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+20
-16
@@ -183,9 +183,10 @@ func TestCreateThing(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
|
||||
repoCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
|
||||
repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
authCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
|
||||
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
|
||||
authCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
authCall3 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(&magistrala.DeletePoliciesRes{Deleted: false}, nil)
|
||||
repoCall3 := cRepo.On("Save", mock.Anything, mock.Anything).Return(convertThings(tc.response), tc.repoErr)
|
||||
rThing, err := mgsdk.CreateThing(tc.client, tc.token)
|
||||
|
||||
@@ -200,9 +201,10 @@ func TestCreateThing(t *testing.T) {
|
||||
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
authCall.Unset()
|
||||
authCall1.Unset()
|
||||
authCall2.Unset()
|
||||
authCall3.Unset()
|
||||
repoCall3.Unset()
|
||||
}
|
||||
}
|
||||
@@ -270,12 +272,13 @@ func TestCreateThings(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
|
||||
repoCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
|
||||
repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall3 := cRepo.On("Save", mock.Anything, mock.Anything).Return(convertThings(tc.response...), tc.err)
|
||||
authCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, nil)
|
||||
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(&magistrala.AddPoliciesRes{Added: true}, nil)
|
||||
authCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
authCall3 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(&magistrala.DeletePoliciesRes{Deleted: false}, nil)
|
||||
repoCall1 := cRepo.On("Save", mock.Anything, mock.Anything).Return(convertThings(tc.response...), tc.err)
|
||||
if len(tc.things) > 0 {
|
||||
repoCall3 = cRepo.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(convertThings(tc.response...), tc.err)
|
||||
repoCall1 = cRepo.On("Save", mock.Anything, mock.Anything, mock.Anything).Return(convertThings(tc.response...), tc.err)
|
||||
}
|
||||
rThing, err := mgsdk.CreateThings(tc.things, tc.token)
|
||||
for i, t := range rThing {
|
||||
@@ -290,17 +293,18 @@ func TestCreateThings(t *testing.T) {
|
||||
if tc.err == nil {
|
||||
switch len(tc.things) {
|
||||
case 1:
|
||||
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
case 2:
|
||||
ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything, mock.Anything)
|
||||
ok := repoCall1.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
}
|
||||
}
|
||||
repoCall.Unset()
|
||||
authCall.Unset()
|
||||
authCall1.Unset()
|
||||
authCall2.Unset()
|
||||
authCall3.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
repoCall3.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+69
-24
@@ -93,35 +93,22 @@ func (svc service) CreateThings(ctx context.Context, token string, cls ...mgclie
|
||||
clients = append(clients, c)
|
||||
}
|
||||
|
||||
if err := svc.addThingPolicies(ctx, user.GetId(), user.GetDomainId(), clients); err != nil {
|
||||
return []mgclients.Client{}, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if errRollback := svc.addThingPoliciesRollback(ctx, user.GetId(), user.GetDomainId(), clients); errRollback != nil {
|
||||
err = errors.Wrap(errors.Wrap(errors.ErrRollbackTx, errRollback), err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
saved, err := svc.clients.Save(ctx, clients...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
policies := magistrala.AddPoliciesReq{}
|
||||
for _, c := range saved {
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: user.GetDomainId(),
|
||||
SubjectType: auth.UserType,
|
||||
Subject: user.GetId(),
|
||||
Relation: auth.AdministratorRelation,
|
||||
ObjectKind: auth.NewThingKind,
|
||||
ObjectType: auth.ThingType,
|
||||
Object: c.ID,
|
||||
})
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: user.GetDomainId(),
|
||||
SubjectType: auth.DomainType,
|
||||
Subject: user.GetDomainId(),
|
||||
Relation: auth.DomainRelation,
|
||||
ObjectType: auth.ThingType,
|
||||
Object: c.ID,
|
||||
})
|
||||
}
|
||||
if _, err := svc.auth.AddPolicies(ctx, &policies); err != nil {
|
||||
return nil, errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
return saved, nil
|
||||
}
|
||||
|
||||
@@ -609,3 +596,61 @@ func (svc *service) authorize(ctx context.Context, domainID, subjType, subjKind,
|
||||
|
||||
return res.GetId(), nil
|
||||
}
|
||||
|
||||
func (svc service) addThingPolicies(ctx context.Context, userID, domainID string, things []mgclients.Client) error {
|
||||
policies := magistrala.AddPoliciesReq{}
|
||||
for _, thing := range things {
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.UserType,
|
||||
Subject: userID,
|
||||
Relation: auth.AdministratorRelation,
|
||||
ObjectKind: auth.NewThingKind,
|
||||
ObjectType: auth.ThingType,
|
||||
Object: thing.ID,
|
||||
})
|
||||
policies.AddPoliciesReq = append(policies.AddPoliciesReq, &magistrala.AddPolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.DomainType,
|
||||
Subject: domainID,
|
||||
Relation: auth.DomainRelation,
|
||||
ObjectType: auth.ThingType,
|
||||
Object: thing.ID,
|
||||
})
|
||||
}
|
||||
|
||||
if _, err := svc.auth.AddPolicies(ctx, &policies); err != nil {
|
||||
return errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) addThingPoliciesRollback(ctx context.Context, userID, domainID string, things []mgclients.Client) error {
|
||||
policies := magistrala.DeletePoliciesReq{}
|
||||
for _, thing := range things {
|
||||
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.UserType,
|
||||
Subject: userID,
|
||||
Relation: auth.AdministratorRelation,
|
||||
ObjectKind: auth.NewThingKind,
|
||||
ObjectType: auth.ThingType,
|
||||
Object: thing.ID,
|
||||
})
|
||||
policies.DeletePoliciesReq = append(policies.DeletePoliciesReq, &magistrala.DeletePolicyReq{
|
||||
Domain: domainID,
|
||||
SubjectType: auth.DomainType,
|
||||
Subject: domainID,
|
||||
Relation: auth.DomainRelation,
|
||||
ObjectType: auth.ThingType,
|
||||
Object: thing.ID,
|
||||
})
|
||||
}
|
||||
|
||||
if _, err := svc.auth.DeletePolicies(ctx, &policies); err != nil {
|
||||
return errors.Wrap(svcerr.ErrRemoveEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+26
-5
@@ -65,9 +65,11 @@ func TestCreateThings(t *testing.T) {
|
||||
token string
|
||||
authResponse *magistrala.AuthorizeRes
|
||||
addPolicyResponse *magistrala.AddPoliciesRes
|
||||
deletePolicyRes *magistrala.DeletePoliciesRes
|
||||
authorizeErr error
|
||||
identifyErr error
|
||||
addPolicyErr error
|
||||
deletePolicyErr error
|
||||
saveErr error
|
||||
err error
|
||||
}{
|
||||
@@ -305,13 +307,31 @@ func TestCreateThings(t *testing.T) {
|
||||
addPolicyErr: svcerr.ErrInvalidPolicy,
|
||||
err: svcerr.ErrInvalidPolicy,
|
||||
},
|
||||
{
|
||||
desc: "create a new thing with failed delete policy response",
|
||||
thing: mgclients.Client{
|
||||
Credentials: mgclients.Credentials{
|
||||
Identity: "newclientwithfailedpolicy@example.com",
|
||||
Secret: secret,
|
||||
},
|
||||
Status: mgclients.EnabledStatus,
|
||||
},
|
||||
token: validToken,
|
||||
authResponse: &magistrala.AuthorizeRes{Authorized: true},
|
||||
addPolicyResponse: &magistrala.AddPoliciesRes{Added: true},
|
||||
saveErr: repoerr.ErrConflict,
|
||||
deletePolicyRes: &magistrala.DeletePoliciesRes{Deleted: false},
|
||||
deletePolicyErr: svcerr.ErrInvalidPolicy,
|
||||
err: repoerr.ErrConflict,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID, DomainId: testsutil.GenerateUUID(t)}, tc.identifyErr)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(tc.authResponse, tc.authorizeErr)
|
||||
repoCall2 := cRepo.On("Save", context.Background(), mock.Anything).Return([]mgclients.Client{tc.thing}, tc.saveErr)
|
||||
repoCall3 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPolicyResponse, tc.addPolicyErr)
|
||||
authcall := auth.On("Authorize", mock.Anything, mock.Anything).Return(tc.authResponse, tc.authorizeErr)
|
||||
repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return([]mgclients.Client{tc.thing}, tc.saveErr)
|
||||
authCall1 := auth.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPolicyResponse, tc.addPolicyErr)
|
||||
authCall2 := auth.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePolicyRes, tc.deletePolicyErr)
|
||||
expected, err := svc.CreateThings(context.Background(), tc.token, tc.thing)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
@@ -324,9 +344,10 @@ func TestCreateThings(t *testing.T) {
|
||||
assert.Equal(t, tc.thing, expected[0], fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.thing, expected[0]))
|
||||
}
|
||||
repoCall.Unset()
|
||||
authcall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
repoCall3.Unset()
|
||||
authCall1.Unset()
|
||||
authCall2.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -84,7 +84,7 @@ func (svc service) RegisterClient(ctx context.Context, token string, cli mgclien
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if errRollback := svc.addClientPolicyRollback(ctx, cli.ID, cli.Role); errRollback != nil {
|
||||
err = errors.Wrap(errors.Wrap(repoerr.ErrRollbackTx, errRollback), err)
|
||||
err = errors.Wrap(errors.Wrap(errors.ErrRollbackTx, errRollback), err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
Reference in New Issue
Block a user