NOISSUE - Remove domain from token (#2468)

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Steve Munene
2024-10-25 15:41:51 +03:00
committed by GitHub
parent ca8ed3b006
commit f88e11bdb2
76 changed files with 1619 additions and 1831 deletions
+8
View File
@@ -48,6 +48,8 @@ paths:
description: Failed due to malformed JSON.
"401":
description: Missing or invalid access token provided.
"403":
description: Unauthorized access to the domain ID.
"404":
description: A non-existent entity request.
"409":
@@ -86,6 +88,8 @@ paths:
description: |
Missing or invalid access token provided.
This endpoint is available only for administrators.
"403":
description: Unauthorized access to the domain ID.
"404":
description: A non-existent entity request.
"422":
@@ -165,6 +169,8 @@ paths:
description: Failed due to malformed query parameters.
"401":
description: Missing or invalid access token provided.
"403":
description: Unauthorized access to the domain ID.
"404":
description: A non-existent entity request.
"415":
@@ -191,6 +197,8 @@ paths:
description: Invitation deleted.
"400":
description: Failed due to malformed JSON.
"403":
description: Unauthorized access to the domain ID.
"404":
description: Failed due to non existing user.
"401":
+7 -31
View File
@@ -204,9 +204,8 @@ type IssueReq struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
DomainId *string `protobuf:"bytes,2,opt,name=domain_id,json=domainId,proto3,oneof" json:"domain_id,omitempty"`
Type uint32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"`
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
Type uint32 `protobuf:"varint,2,opt,name=type,proto3" json:"type,omitempty"`
}
func (x *IssueReq) Reset() {
@@ -248,13 +247,6 @@ func (x *IssueReq) GetUserId() string {
return ""
}
func (x *IssueReq) GetDomainId() string {
if x != nil && x.DomainId != nil {
return *x.DomainId
}
return ""
}
func (x *IssueReq) GetType() uint32 {
if x != nil {
return x.Type
@@ -267,8 +259,7 @@ type RefreshReq struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
RefreshToken string `protobuf:"bytes,1,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
DomainId *string `protobuf:"bytes,2,opt,name=domain_id,json=domainId,proto3,oneof" json:"domain_id,omitempty"`
RefreshToken string `protobuf:"bytes,1,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
}
func (x *RefreshReq) Reset() {
@@ -310,13 +301,6 @@ func (x *RefreshReq) GetRefreshToken() string {
return ""
}
func (x *RefreshReq) GetDomainId() string {
if x != nil && x.DomainId != nil {
return *x.DomainId
}
return ""
}
type AuthZReq struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@@ -723,19 +707,13 @@ var file_auth_proto_rawDesc = []byte{
0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75,
0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f,
0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e,
0x49, 0x64, 0x22, 0x67, 0x0a, 0x08, 0x49, 0x73, 0x73, 0x75, 0x65, 0x52, 0x65, 0x71, 0x12, 0x17,
0x49, 0x64, 0x22, 0x37, 0x0a, 0x08, 0x49, 0x73, 0x73, 0x75, 0x65, 0x52, 0x65, 0x71, 0x12, 0x17,
0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x20, 0x0a, 0x09, 0x64, 0x6f, 0x6d, 0x61, 0x69,
0x6e, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x08, 0x64, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70,
0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x42, 0x0c, 0x0a,
0x0a, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x22, 0x61, 0x0a, 0x0a, 0x52,
0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18,
0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x22, 0x31, 0x0a, 0x0a, 0x52,
0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, 0x66,
0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
0x52, 0x0c, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x20,
0x0a, 0x09, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28,
0x09, 0x48, 0x00, 0x52, 0x08, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01,
0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x22, 0xa2,
0x52, 0x0c, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x22, 0xa2,
0x02, 0x0a, 0x08, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x71, 0x12, 0x16, 0x0a, 0x06, 0x64,
0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x6f, 0x6d,
0x61, 0x69, 0x6e, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x75, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x5f, 0x74,
@@ -993,8 +971,6 @@ func file_auth_proto_init() {
}
}
file_auth_proto_msgTypes[0].OneofWrappers = []any{}
file_auth_proto_msgTypes[3].OneofWrappers = []any{}
file_auth_proto_msgTypes[4].OneofWrappers = []any{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
+1 -3
View File
@@ -53,13 +53,11 @@ message AuthNRes {
message IssueReq {
string user_id = 1;
optional string domain_id = 2;
uint32 type = 3;
uint32 type = 2;
}
message RefreshReq {
string refresh_token = 1;
optional string domain_id = 2;
}
message AuthZReq {
+6 -8
View File
@@ -53,9 +53,8 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *magistrala.IssueRe
defer cancel()
res, err := client.issue(ctx, issueReq{
userID: req.GetUserId(),
domainID: req.GetDomainId(),
keyType: auth.KeyType(req.GetType()),
userID: req.GetUserId(),
keyType: auth.KeyType(req.GetType()),
})
if err != nil {
return &magistrala.Token{}, grpcapi.DecodeError(err)
@@ -66,9 +65,8 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *magistrala.IssueRe
func encodeIssueRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(issueReq)
return &magistrala.IssueReq{
UserId: req.userID,
DomainId: &req.domainID,
Type: uint32(req.keyType),
UserId: req.userID,
Type: uint32(req.keyType),
}, nil
}
@@ -80,7 +78,7 @@ func (client tokenGrpcClient) Refresh(ctx context.Context, req *magistrala.Refre
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
res, err := client.refresh(ctx, refreshReq{refreshToken: req.GetRefreshToken(), domainID: req.GetDomainId()})
res, err := client.refresh(ctx, refreshReq{refreshToken: req.GetRefreshToken()})
if err != nil {
return &magistrala.Token{}, grpcapi.DecodeError(err)
}
@@ -89,7 +87,7 @@ func (client tokenGrpcClient) Refresh(ctx context.Context, req *magistrala.Refre
func encodeRefreshRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(refreshReq)
return &magistrala.RefreshReq{RefreshToken: req.refreshToken, DomainId: &req.domainID}, nil
return &magistrala.RefreshReq{RefreshToken: req.refreshToken}, nil
}
func decodeRefreshResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
+3 -4
View File
@@ -18,9 +18,8 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint {
}
key := auth.Key{
Type: req.keyType,
User: req.userID,
Domain: req.domainID,
Type: req.keyType,
User: req.userID,
}
tkn, err := svc.Issue(ctx, "", key)
if err != nil {
@@ -42,7 +41,7 @@ func refreshEndpoint(svc auth.Service) endpoint.Endpoint {
return issueRes{}, err
}
key := auth.Key{Domain: req.domainID, Type: auth.RefreshKey}
key := auth.Key{Type: auth.RefreshKey}
tkn, err := svc.Issue(ctx, req.refreshToken, key)
if err != nil {
return issueRes{}, err
+20 -27
View File
@@ -46,7 +46,6 @@ const (
var (
validID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
authAddr = fmt.Sprintf("localhost:%d", port)
)
@@ -70,16 +69,14 @@ func TestIssue(t *testing.T) {
cases := []struct {
desc string
userId string
domainID string
kind auth.KeyType
issueResponse auth.Token
err error
}{
{
desc: "issue for user with valid token",
userId: validID,
domainID: domainID,
kind: auth.AccessKey,
desc: "issue for user with valid token",
userId: validID,
kind: auth.AccessKey,
issueResponse: auth.Token{
AccessToken: validToken,
RefreshToken: validToken,
@@ -87,10 +84,9 @@ func TestIssue(t *testing.T) {
err: nil,
},
{
desc: "issue recovery key",
userId: validID,
domainID: domainID,
kind: auth.RecoveryKey,
desc: "issue recovery key",
userId: validID,
kind: auth.RecoveryKey,
issueResponse: auth.Token{
AccessToken: validToken,
RefreshToken: validToken,
@@ -100,7 +96,6 @@ func TestIssue(t *testing.T) {
{
desc: "issue API key unauthenticated",
userId: validID,
domainID: domainID,
kind: auth.APIKey,
issueResponse: auth.Token{},
err: svcerr.ErrAuthentication,
@@ -108,7 +103,6 @@ func TestIssue(t *testing.T) {
{
desc: "issue for invalid key type",
userId: validID,
domainID: domainID,
kind: 32,
issueResponse: auth.Token{},
err: errors.ErrMalformedEntity,
@@ -116,7 +110,6 @@ func TestIssue(t *testing.T) {
{
desc: "issue for user that does notexist",
userId: "",
domainID: "",
kind: auth.APIKey,
issueResponse: auth.Token{},
err: svcerr.ErrAuthentication,
@@ -124,10 +117,12 @@ func TestIssue(t *testing.T) {
}
for _, tc := range cases {
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
_, err := grpcClient.Issue(context.Background(), &magistrala.IssueReq{UserId: tc.userId, DomainId: &tc.domainID, Type: uint32(tc.kind)})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
_, err := grpcClient.Issue(context.Background(), &magistrala.IssueReq{UserId: tc.userId, Type: uint32(tc.kind)})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
@@ -139,14 +134,12 @@ func TestRefresh(t *testing.T) {
cases := []struct {
desc string
token string
domainID string
issueResponse auth.Token
err error
}{
{
desc: "refresh token with valid token",
token: validToken,
domainID: domainID,
desc: "refresh token with valid token",
token: validToken,
issueResponse: auth.Token{
AccessToken: validToken,
RefreshToken: validToken,
@@ -156,23 +149,23 @@ func TestRefresh(t *testing.T) {
{
desc: "refresh token with invalid token",
token: inValidToken,
domainID: domainID,
issueResponse: auth.Token{},
err: svcerr.ErrAuthentication,
},
{
desc: "refresh token with empty token",
token: "",
domainID: domainID,
issueResponse: auth.Token{},
err: apiutil.ErrMissingSecret,
},
}
for _, tc := range cases {
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
_, err := grpcClient.Refresh(context.Background(), &magistrala.RefreshReq{DomainId: &tc.domainID, RefreshToken: tc.token})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
_, err := grpcClient.Refresh(context.Background(), &magistrala.RefreshReq{RefreshToken: tc.token})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
+2 -4
View File
@@ -9,9 +9,8 @@ import (
)
type issueReq struct {
userID string
domainID string // optional
keyType auth.KeyType
userID string
keyType auth.KeyType
}
func (req issueReq) validate() error {
@@ -27,7 +26,6 @@ func (req issueReq) validate() error {
type refreshReq struct {
refreshToken string
domainID string // optional
}
func (req refreshReq) validate() error {
+3 -4
View File
@@ -55,15 +55,14 @@ func (s *tokenGrpcServer) Refresh(ctx context.Context, req *magistrala.RefreshRe
func decodeIssueRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*magistrala.IssueReq)
return issueReq{
userID: req.GetUserId(),
domainID: req.GetDomainId(),
keyType: auth.KeyType(req.GetType()),
userID: req.GetUserId(),
keyType: auth.KeyType(req.GetType()),
}, nil
}
func decodeRefreshRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*magistrala.RefreshReq)
return refreshReq{refreshToken: req.GetRefreshToken(), domainID: req.GetDomainId()}, nil
return refreshReq{refreshToken: req.GetRefreshToken()}, nil
}
func encodeIssueResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
+14 -10
View File
@@ -119,11 +119,13 @@ func TestIssue(t *testing.T) {
},
}
for _, tc := range cases {
tkn, err := tokenizer.Issue(tc.key)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
if err != nil {
assert.NotEmpty(t, tkn, fmt.Sprintf("%s expected token, got empty string", tc.desc))
}
t.Run(tc.desc, func(t *testing.T) {
tkn, err := tokenizer.Issue(tc.key)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
if err != nil {
assert.NotEmpty(t, tkn, fmt.Sprintf("%s expected token, got empty string", tc.desc))
}
})
}
}
@@ -225,11 +227,13 @@ func TestParse(t *testing.T) {
}
for _, tc := range cases {
key, err := tokenizer.Parse(tc.token)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key))
}
t.Run(tc.desc, func(t *testing.T) {
key, err := tokenizer.Parse(tc.token)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key))
}
})
}
}
-4
View File
@@ -32,7 +32,6 @@ const (
issuerName = "magistrala.auth"
tokenType = "type"
userField = "user"
domainField = "domain"
oauthProviderField = "oauth_provider"
oauthAccessTokenField = "access_token"
oauthRefreshTokenField = "refresh_token"
@@ -59,9 +58,6 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) {
Claim(tokenType, key.Type).
Expiration(key.ExpiresAt)
builder.Claim(userField, key.User)
if key.Domain != "" {
builder.Claim(domainField, key.Domain)
}
if key.Subject != "" {
builder.Subject(key.Subject)
}
+16 -6
View File
@@ -613,10 +613,15 @@ func (svc service) ListDomains(ctx context.Context, token string, p Page) (Domai
}
func (svc service) AssignUsers(ctx context.Context, token, id string, userIds []string, relation string) error {
res, err := svc.Identify(ctx, token)
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.Authorize(ctx, policies.Policy{
Subject: token,
Subject: res.User,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
@@ -625,9 +630,9 @@ func (svc service) AssignUsers(ctx context.Context, token, id string, userIds []
}
if err := svc.Authorize(ctx, policies.Policy{
Subject: token,
Subject: res.User,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: SwitchToPermission(relation),
@@ -651,10 +656,15 @@ func (svc service) AssignUsers(ctx context.Context, token, id string, userIds []
}
func (svc service) UnassignUser(ctx context.Context, token, id, userID string) error {
res, err := svc.Identify(ctx, token)
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
pr := policies.Policy{
Subject: token,
Subject: res.User,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: id,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
+248 -237
View File
@@ -308,18 +308,20 @@ func TestIssue(t *testing.T) {
},
}
for _, tc := range cases2 {
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyRequest).Return(tc.checkPolicyErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPlatformPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(tc.retrieveByIDResponse, tc.retreiveByIDErr)
repoCall4 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr)
_, err := svc.Issue(context.Background(), tc.token, tc.key)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyRequest).Return(tc.checkPolicyErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPlatformPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(tc.retrieveByIDResponse, tc.retreiveByIDErr)
repoCall4 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr)
_, err := svc.Issue(context.Background(), tc.token, tc.key)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
})
}
cases3 := []struct {
@@ -405,6 +407,7 @@ func TestIssue(t *testing.T) {
key: auth.Key{
Type: auth.RefreshKey,
IssuedAt: time.Now(),
Domain: groupName,
},
checkPolicyRequest: policies.Policy{
Subject: email,
@@ -554,10 +557,12 @@ func TestRevoke(t *testing.T) {
}
for _, tc := range cases {
repocall := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
err := svc.Revoke(context.Background(), tc.token, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repocall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repocall := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
err := svc.Revoke(context.Background(), tc.token, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repocall.Unset()
})
}
}
@@ -624,10 +629,12 @@ func TestRetrieve(t *testing.T) {
}
for _, tc := range cases {
repocall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(auth.Key{}, tc.err)
_, err := svc.RetrieveKey(context.Background(), tc.token, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repocall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repocall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(auth.Key{}, tc.err)
_, err := svc.RetrieveKey(context.Background(), tc.token, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repocall.Unset()
})
}
}
@@ -726,13 +733,15 @@ func TestIdentify(t *testing.T) {
}
for _, tc := range cases {
repocall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(auth.Key{}, tc.err)
repocall1 := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
idt, err := svc.Identify(context.Background(), tc.key)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.idt, idt.Subject, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.idt, idt))
repocall.Unset()
repocall1.Unset()
t.Run(tc.desc, func(t *testing.T) {
repocall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(auth.Key{}, tc.err)
repocall1 := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
idt, err := svc.Identify(context.Background(), tc.key)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.idt, idt.Subject, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.idt, idt))
repocall.Unset()
repocall1.Unset()
})
}
}
@@ -791,7 +800,7 @@ func TestAuthorize(t *testing.T) {
Permission: policies.AdminPermission,
},
checkPolicyReq3: policies.Policy{
Domain: groupName,
Domain: "",
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
@@ -1164,18 +1173,20 @@ func TestAuthorize(t *testing.T) {
},
}
for _, tc := range cases {
repoCall := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq3).Return(tc.checkPolicyErr)
repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(tc.retrieveDomainRes, nil)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr1)
repoCall4 := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(nil)
err := svc.Authorize(context.Background(), tc.policyReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq3).Return(tc.checkPolicyErr)
repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(tc.retrieveDomainRes, nil)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr1)
repoCall4 := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(nil)
err := svc.Authorize(context.Background(), tc.policyReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
})
}
cases2 := []struct {
desc string
@@ -1195,8 +1206,10 @@ func TestAuthorize(t *testing.T) {
},
}
for _, tc := range cases2 {
err := svc.Authorize(context.Background(), tc.policyReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
t.Run(tc.desc, func(t *testing.T) {
err := svc.Authorize(context.Background(), tc.policyReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
@@ -1238,8 +1251,10 @@ func TestSwitchToPermission(t *testing.T) {
},
}
for _, tc := range cases {
result := auth.SwitchToPermission(tc.relation)
assert.Equal(t, tc.result, result, fmt.Sprintf("switching to permission expected to succeed: %s", result))
t.Run(tc.desc, func(t *testing.T) {
result := auth.SwitchToPermission(tc.relation)
assert.Equal(t, tc.result, result, fmt.Sprintf("switching to permission expected to succeed: %s", result))
})
}
}
@@ -1353,18 +1368,20 @@ func TestCreateDomain(t *testing.T) {
}
for _, tc := range cases {
repoCall := pService.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPolicyErr)
repoCall1 := drepo.On("SavePolicies", mock.Anything, mock.Anything).Return(tc.savePolicyErr)
repoCall2 := pService.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
repoCall3 := drepo.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deleteDomainErr)
repoCall4 := drepo.On("Save", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.saveDomainErr)
_, err := svc.CreateDomain(context.Background(), tc.token, tc.d)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := pService.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPolicyErr)
repoCall1 := drepo.On("SavePolicies", mock.Anything, mock.Anything).Return(tc.savePolicyErr)
repoCall2 := pService.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
repoCall3 := drepo.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deleteDomainErr)
repoCall4 := drepo.On("Save", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.saveDomainErr)
_, err := svc.CreateDomain(context.Background(), tc.token, tc.d)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
})
}
}
@@ -1417,14 +1434,16 @@ func TestRetrieveDomain(t *testing.T) {
}
for _, tc := range cases {
repoCall := drepo.On("RetrieveByID", mock.Anything, groupName).Return(auth.Domain{}, tc.domainRepoErr)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall2 := drepo.On("RetrieveByID", mock.Anything, tc.domainID).Return(auth.Domain{}, tc.domainRepoErr1)
_, err := svc.RetrieveDomain(context.Background(), tc.token, tc.domainID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveByID", mock.Anything, groupName).Return(auth.Domain{}, tc.domainRepoErr)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall2 := drepo.On("RetrieveByID", mock.Anything, tc.domainID).Return(auth.Domain{}, tc.domainRepoErr1)
_, err := svc.RetrieveDomain(context.Background(), tc.token, tc.domainID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
@@ -1476,14 +1495,16 @@ func TestRetrieveDomainPermissions(t *testing.T) {
}
for _, tc := range cases {
repoCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(policies.Permissions{}, tc.retreivePermissionsErr)
repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retreiveByIDErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
_, err := svc.RetrieveDomainPermissions(context.Background(), tc.token, tc.domainID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(policies.Permissions{}, tc.retreivePermissionsErr)
repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retreiveByIDErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
_, err := svc.RetrieveDomainPermissions(context.Background(), tc.token, tc.domainID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
@@ -1556,14 +1577,16 @@ func TestUpdateDomain(t *testing.T) {
}
for _, tc := range cases {
repoCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retrieveByIDErr)
repoCall2 := drepo.On("Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(auth.Domain{}, tc.updateErr)
_, err := svc.UpdateDomain(context.Background(), tc.token, tc.domainID, tc.domReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retrieveByIDErr)
repoCall2 := drepo.On("Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(auth.Domain{}, tc.updateErr)
_, err := svc.UpdateDomain(context.Background(), tc.token, tc.domainID, tc.domReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
@@ -1633,14 +1656,16 @@ func TestChangeDomainStatus(t *testing.T) {
}
for _, tc := range cases {
repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retreieveByIDErr)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall2 := drepo.On("Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(auth.Domain{}, tc.updateErr)
_, err := svc.ChangeDomainStatus(context.Background(), tc.token, tc.domainID, tc.domainReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retreieveByIDErr)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall2 := drepo.On("Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(auth.Domain{}, tc.updateErr)
_, err := svc.ChangeDomainStatus(context.Background(), tc.token, tc.domainID, tc.domainReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
@@ -1701,12 +1726,14 @@ func TestListDomains(t *testing.T) {
}
for _, tc := range cases {
repoCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall1 := drepo.On("ListDomains", mock.Anything, mock.Anything).Return(tc.listDomainsRes, tc.listDomainErr)
_, err := svc.ListDomains(context.Background(), tc.token, auth.Page{})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall1 := drepo.On("ListDomains", mock.Anything, mock.Anything).Return(tc.listDomainsRes, tc.listDomainErr)
_, err := svc.ListDomains(context.Background(), tc.token, auth.Page{})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
})
}
}
@@ -1738,19 +1765,17 @@ func TestAssignUsers(t *testing.T) {
userIDs: []string{validID},
relation: policies.ContributorRelation,
checkPolicyReq3: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.ViewPermission,
@@ -1763,13 +1788,12 @@ func TestAssignUsers(t *testing.T) {
Permission: policies.MembershipPermission,
},
checkPolicyReq33: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
err: nil,
},
{
@@ -1779,19 +1803,18 @@ func TestAssignUsers(t *testing.T) {
userIDs: []string{validID},
relation: policies.ContributorRelation,
checkPolicyReq3: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.ViewPermission,
@@ -1811,27 +1834,25 @@ func TestAssignUsers(t *testing.T) {
domainID: inValid,
relation: policies.ContributorRelation,
checkPolicyReq3: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: inValid,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: inValid,
ObjectType: policies.DomainType,
Permission: policies.ViewPermission,
},
checkPolicyReq33: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: inValid,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
@@ -1845,19 +1866,17 @@ func TestAssignUsers(t *testing.T) {
domainID: validID,
relation: policies.ContributorRelation,
checkPolicyReq3: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.ViewPermission,
@@ -1870,9 +1889,9 @@ func TestAssignUsers(t *testing.T) {
Permission: policies.MembershipPermission,
},
checkPolicyReq33: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
@@ -1886,19 +1905,17 @@ func TestAssignUsers(t *testing.T) {
userIDs: []string{validID},
relation: policies.ContributorRelation,
checkPolicyReq3: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.ViewPermission,
@@ -1911,9 +1928,9 @@ func TestAssignUsers(t *testing.T) {
Permission: policies.MembershipPermission,
},
checkPolicyReq33: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
@@ -1927,19 +1944,17 @@ func TestAssignUsers(t *testing.T) {
userIDs: []string{validID},
relation: policies.ContributorRelation,
checkPolicyReq3: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.ViewPermission,
@@ -1952,9 +1967,9 @@ func TestAssignUsers(t *testing.T) {
Permission: policies.MembershipPermission,
},
checkPolicyReq33: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
@@ -1968,19 +1983,17 @@ func TestAssignUsers(t *testing.T) {
userIDs: []string{validID},
relation: policies.ContributorRelation,
checkPolicyReq3: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.ViewPermission,
@@ -1993,9 +2006,9 @@ func TestAssignUsers(t *testing.T) {
Permission: policies.MembershipPermission,
},
checkPolicyReq33: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
@@ -2006,24 +2019,26 @@ func TestAssignUsers(t *testing.T) {
}
for _, tc := range cases {
repoCall := drepo.On("RetrieveByID", mock.Anything, groupName).Return(auth.Domain{}, nil)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq3).Return(tc.checkpolicyErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr2)
repoCall4 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq33).Return(tc.checkPolicyErr2)
repoCall5 := pService.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPoliciesErr)
repoCall6 := drepo.On("SavePolicies", mock.Anything, mock.Anything, mock.Anything).Return(tc.savePoliciesErr)
repoCall7 := pService.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
err := svc.AssignUsers(context.Background(), tc.token, tc.domainID, tc.userIDs, tc.relation)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
repoCall5.Unset()
repoCall6.Unset()
repoCall7.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, nil)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq3).Return(tc.checkpolicyErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr2)
repoCall4 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq33).Return(tc.checkPolicyErr2)
repoCall5 := pService.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPoliciesErr)
repoCall6 := drepo.On("SavePolicies", mock.Anything, mock.Anything, mock.Anything).Return(tc.savePoliciesErr)
repoCall7 := pService.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
err := svc.AssignUsers(context.Background(), tc.token, tc.domainID, tc.userIDs, tc.relation)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
repoCall5.Unset()
repoCall6.Unset()
repoCall7.Unset()
})
}
}
@@ -2050,26 +2065,24 @@ func TestUnassignUser(t *testing.T) {
domainID: validID,
userID: validID,
checkPolicyReq: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
},
checkDomainPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
@@ -2082,19 +2095,17 @@ func TestUnassignUser(t *testing.T) {
domainID: validID,
userID: validID,
checkPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
@@ -2107,27 +2118,25 @@ func TestUnassignUser(t *testing.T) {
domainID: inValid,
userID: validID,
checkPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: inValid,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: inValid,
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
},
checkDomainPolicyReq: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: inValid,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
@@ -2140,27 +2149,25 @@ func TestUnassignUser(t *testing.T) {
domainID: validID,
userID: validID,
checkPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
},
checkDomainPolicyReq: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
@@ -2173,27 +2180,25 @@ func TestUnassignUser(t *testing.T) {
domainID: validID,
userID: validID,
checkPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
},
checkDomainPolicyReq: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
@@ -2207,26 +2212,24 @@ func TestUnassignUser(t *testing.T) {
domainID: validID,
userID: validID,
checkPolicyReq: policies.Policy{
Subject: id,
Subject: email,
SubjectType: policies.UserType,
Object: groupName,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.MembershipPermission,
},
checkAdminPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.AdminPermission,
},
checkDomainPolicyReq: policies.Policy{
Domain: groupName,
Subject: id,
Subject: email,
SubjectType: policies.UserType,
SubjectKind: policies.TokenKind,
SubjectKind: policies.UsersKind,
Object: validID,
ObjectType: policies.DomainType,
Permission: policies.SharePermission,
@@ -2237,20 +2240,22 @@ func TestUnassignUser(t *testing.T) {
}
for _, tc := range cases {
repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, nil)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq).Return(tc.checkPolicyErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr1)
repoCall4 := pService.On("DeletePolicyFilter", mock.Anything, mock.Anything).Return(tc.deletePolicyFilterErr)
repoCall5 := drepo.On("DeletePolicies", mock.Anything, mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
err := svc.UnassignUser(context.Background(), tc.token, tc.domainID, tc.userID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
repoCall5.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, nil)
repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq).Return(tc.checkPolicyErr)
repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1)
repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr1)
repoCall4 := pService.On("DeletePolicyFilter", mock.Anything, mock.Anything).Return(tc.deletePolicyFilterErr)
repoCall5 := drepo.On("DeletePolicies", mock.Anything, mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
err := svc.UnassignUser(context.Background(), tc.token, tc.domainID, tc.userID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
repoCall5.Unset()
})
}
}
@@ -2327,12 +2332,14 @@ func TestListUsersDomains(t *testing.T) {
}
for _, tc := range cases {
repoCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall1 := drepo.On("ListDomains", mock.Anything, mock.Anything).Return(auth.DomainsPage{}, tc.listDomainErr)
_, err := svc.ListUserDomains(context.Background(), tc.token, tc.userID, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr)
repoCall1 := drepo.On("ListDomains", mock.Anything, mock.Anything).Return(auth.DomainsPage{}, tc.listDomainErr)
_, err := svc.ListUserDomains(context.Background(), tc.token, tc.userID, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
})
}
}
@@ -2370,8 +2377,10 @@ func TestEncodeDomainUserID(t *testing.T) {
}
for _, tc := range cases {
ar := auth.EncodeDomainUserID(tc.domainID, tc.userID)
assert.Equal(t, tc.response, ar, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.response, ar))
t.Run(tc.desc, func(t *testing.T) {
ar := auth.EncodeDomainUserID(tc.domainID, tc.userID)
assert.Equal(t, tc.response, ar, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.response, ar))
})
}
}
@@ -2409,8 +2418,10 @@ func TestDecodeDomainUserID(t *testing.T) {
}
for _, tc := range cases {
ar, er := auth.DecodeDomainUserID(tc.domainUserID)
assert.Equal(t, tc.respUserID, er, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.respUserID, er))
assert.Equal(t, tc.respDomainID, ar, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.respDomainID, ar))
t.Run(tc.desc, func(t *testing.T) {
ar, er := auth.DecodeDomainUserID(tc.domainUserID)
assert.Equal(t, tc.respUserID, er, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.respUserID, er))
assert.Equal(t, tc.respDomainID, ar, fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.respDomainID, ar))
})
}
}
+8 -8
View File
@@ -324,7 +324,7 @@ func TestAdd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
@@ -425,7 +425,7 @@ func TestView(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("View", mock.Anything, tc.session, tc.id).Return(c, tc.err)
@@ -542,7 +542,7 @@ func TestUpdate(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("Update", mock.Anything, tc.session, mock.Anything).Return(tc.err)
@@ -650,7 +650,7 @@ func TestUpdateCert(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("UpdateCert", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(c, tc.err)
@@ -771,7 +771,7 @@ func TestUpdateConnections(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
repoCall := svc.On("UpdateConnections", mock.Anything, tc.session, tc.token, mock.Anything, mock.Anything).Return(tc.err)
@@ -1040,7 +1040,7 @@ func TestList(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("List", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(bootstrap.ConfigsPage{Total: tc.res.Total, Offset: tc.res.Offset, Limit: tc.res.Limit}, tc.err)
@@ -1123,7 +1123,7 @@ func TestRemove(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
@@ -1372,7 +1372,7 @@ func TestChangeState(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("ChangeState", mock.Anything, tc.session, tc.token, mock.Anything, mock.Anything).Return(tc.err)
+2 -2
View File
@@ -49,7 +49,7 @@ func MakeHandler(svc bootstrap.Service, authn mgauthn.Authentication, reader boo
r.Route("/{domainID}/things", func(r chi.Router) {
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn))
r.Use(api.AuthenticateMiddleware(authn, true))
r.Route("/configs", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
@@ -96,7 +96,7 @@ func MakeHandler(svc bootstrap.Service, authn mgauthn.Authentication, reader boo
})
})
r.With(api.AuthenticateMiddleware(authn)).Put("/state/{thingID}", otelhttp.NewHandler(kithttp.NewServer(
r.With(api.AuthenticateMiddleware(authn, true)).Put("/state/{thingID}", otelhttp.NewHandler(kithttp.NewServer(
stateEndpoint(svc),
decodeStateRequest,
api.EncodeResponse,
+1 -1
View File
@@ -40,7 +40,7 @@ func MakeHandler(svc certs.Service, authn mgauthn.Authentication, logger *slog.L
r := chi.NewRouter()
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn))
r.Use(api.AuthenticateMiddleware(authn, true))
r.Route("/{domainID}", func(r chi.Router) {
r.Route("/certs", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
+5 -6
View File
@@ -48,14 +48,13 @@ var cmdBootstrap = []cobra.Command{
return
}
pageMetadata := mgxsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
State: State,
Name: Name,
DomainID: args[1],
Offset: Offset,
Limit: Limit,
State: State,
Name: Name,
}
if args[0] == "all" {
l, err := sdk.Bootstraps(pageMetadata, args[2])
l, err := sdk.Bootstraps(pageMetadata, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
+1 -1
View File
@@ -189,7 +189,7 @@ func TestGetBootstrapConfigCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("ViewBootstrap", tc.args[0], tc.args[1], tc.args[2]).Return(tc.boot, tc.sdkErr)
sdkCall1 := sdkMock.On("Bootstraps", mock.Anything, tc.args[2]).Return(tc.page, tc.sdkErr)
sdkCall1 := sdkMock.On("Bootstraps", mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...)
+10 -13
View File
@@ -63,7 +63,7 @@ var cmdChannels = []cobra.Command{
}
if args[0] == all {
l, err := sdk.Channels(pageMetadata, args[2])
l, err := sdk.Channels(pageMetadata, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
@@ -134,11 +134,10 @@ var cmdChannels = []cobra.Command{
return
}
pm := mgxsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
DomainID: args[1],
Offset: Offset,
Limit: Limit,
}
cl, err := sdk.ThingsByChannel(args[0], pm, args[2])
cl, err := sdk.ThingsByChannel(args[0], pm, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
@@ -197,11 +196,10 @@ var cmdChannels = []cobra.Command{
return
}
pm := mgxsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
DomainID: args[1],
Offset: Offset,
Limit: Limit,
}
ul, err := sdk.ListChannelUsers(args[0], pm, args[2])
ul, err := sdk.ListChannelUsers(args[0], pm, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
@@ -222,11 +220,10 @@ var cmdChannels = []cobra.Command{
return
}
pm := mgxsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
DomainID: args[1],
Offset: Offset,
Limit: Limit,
}
ul, err := sdk.ListChannelUserGroups(args[0], pm, args[2])
ul, err := sdk.ListChannelUserGroups(args[0], pm, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
+4 -4
View File
@@ -182,7 +182,7 @@ func TestGetChannelsCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("Channel", tc.args[0], tc.args[1], tc.args[2]).Return(tc.channel, tc.sdkErr)
sdkCall1 := sdkMock.On("Channels", mock.Anything, tc.args[2]).Return(tc.page, tc.sdkErr)
sdkCall1 := sdkMock.On("Channels", mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...)
@@ -425,7 +425,7 @@ func TestListConnectionsCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("ThingsByChannel", tc.args[0], mock.Anything, tc.args[2]).Return(tc.page, tc.sdkErr)
sdkCall := sdkMock.On("ThingsByChannel", tc.args[0], mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{connsCmd}, tc.args...)...)
switch tc.logType {
case entityLog:
@@ -665,7 +665,7 @@ func TestUsersChannelCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("ListChannelUsers", tc.args[0], mock.Anything, tc.args[2]).Return(tc.page, tc.sdkErr)
sdkCall := sdkMock.On("ListChannelUsers", tc.args[0], mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{usrCmd}, tc.args...)...)
switch tc.logType {
@@ -741,7 +741,7 @@ func TestListGroupCmd(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("ListChannelUserGroups", tc.args[0], mock.Anything, tc.args[2]).Return(tc.page, tc.sdkErr)
sdkCall := sdkMock.On("ListChannelUserGroups", tc.args[0], mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{grpCmd}, tc.args...)...)
switch tc.logType {
case entityLog:
+15 -19
View File
@@ -84,11 +84,10 @@ var cmdGroups = []cobra.Command{
return
}
pm := mgxsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
DomainID: args[1],
Offset: Offset,
Limit: Limit,
}
l, err := sdk.Groups(pm, args[2])
l, err := sdk.Groups(pm, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
@@ -106,7 +105,7 @@ var cmdGroups = []cobra.Command{
Limit: Limit,
DomainID: args[2],
}
l, err := sdk.Children(args[1], pm, args[3])
l, err := sdk.Children(args[1], pm, args[2], args[3])
if err != nil {
logErrorCmd(*cmd, err)
return
@@ -120,11 +119,10 @@ var cmdGroups = []cobra.Command{
return
}
pm := mgxsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
DomainID: args[2],
Offset: Offset,
Limit: Limit,
}
l, err := sdk.Parents(args[1], pm, args[3])
l, err := sdk.Parents(args[1], pm, args[2], args[3])
if err != nil {
logErrorCmd(*cmd, err)
return
@@ -174,12 +172,11 @@ var cmdGroups = []cobra.Command{
return
}
pm := mgxsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
Status: Status,
DomainID: args[1],
Offset: Offset,
Limit: Limit,
Status: Status,
}
users, err := sdk.ListGroupUsers(args[0], pm, args[2])
users, err := sdk.ListGroupUsers(args[0], pm, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
@@ -199,12 +196,11 @@ var cmdGroups = []cobra.Command{
return
}
pm := mgxsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
Status: Status,
DomainID: args[1],
Offset: Offset,
Limit: Limit,
Status: Status,
}
channels, err := sdk.ListGroupChannels(args[0], pm, args[2])
channels, err := sdk.ListGroupChannels(args[0], pm, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
+18 -18
View File
@@ -156,19 +156,19 @@ func TestGetGroupsCmd(t *testing.T) {
},
logType: usageLog,
},
// {
// desc: "get children groups successfully",
// args: []string{
// childCmd,
// group.ID,
// domainID,
// token,
// },
// page: mgsdk.GroupsPage{
// Groups: []mgsdk.Group{group},
// },
// logType: entityLog,
// },
{
desc: "get children groups successfully",
args: []string{
childCmd,
group.ID,
domainID,
token,
},
page: mgsdk.GroupsPage{
Groups: []mgsdk.Group{group},
},
logType: entityLog,
},
{
desc: "get children groups with invalid args",
args: []string{
@@ -293,9 +293,9 @@ func TestGetGroupsCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("Group", mock.Anything, mock.Anything, mock.Anything).Return(tc.group, tc.sdkErr)
sdkCall1 := sdkMock.On("Groups", mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
sdkCall2 := sdkMock.On("Parents", mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
sdkCall3 := sdkMock.On("Children", mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
sdkCall1 := sdkMock.On("Groups", mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
sdkCall2 := sdkMock.On("Parents", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
sdkCall3 := sdkMock.On("Children", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...)
@@ -535,7 +535,7 @@ func TestListUsersCmd(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("ListGroupUsers", tc.args[0], mock.Anything, tc.args[2]).Return(tc.page, tc.sdkErr)
sdkCall := sdkMock.On("ListGroupUsers", tc.args[0], mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{usrCmd}, tc.args...)...)
switch tc.logType {
case entityLog:
@@ -610,7 +610,7 @@ func TestListChannelsCmd(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("ListGroupChannels", tc.args[0], mock.Anything, tc.args[2]).Return(tc.page, tc.sdkErr)
sdkCall := sdkMock.On("ListGroupChannels", tc.args[0], mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{chansCmd}, tc.args...)...)
switch tc.logType {
case entityLog:
+1 -1
View File
@@ -53,7 +53,7 @@ var cmdInvitations = []cobra.Command{
Limit: Limit,
}
if args[0] == all {
l, err := sdk.Invitations(pageMetadata, args[1])
l, err := sdk.Invitations(pageMetadata, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
+3 -1
View File
@@ -113,6 +113,7 @@ func TestGetInvitationCmd(t *testing.T) {
desc: "get all invitations successfully",
args: []string{
all,
domain.ID,
token,
},
page: mgsdk.InvitationPage{
@@ -147,6 +148,7 @@ func TestGetInvitationCmd(t *testing.T) {
desc: "get all invitations with invalid token",
args: []string{
all,
domain.ID,
invalidToken,
},
logType: errLog,
@@ -169,7 +171,7 @@ func TestGetInvitationCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("Invitation", tc.args[0], tc.args[1], mock.Anything).Return(tc.inv, tc.sdkErr)
sdkCall1 := sdkMock.On("Invitations", mock.Anything, tc.args[1]).Return(tc.page, tc.sdkErr)
sdkCall1 := sdkMock.On("Invitations", mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...)
+1 -2
View File
@@ -166,8 +166,7 @@ var cmdProvision = []cobra.Command{
return
}
// domain login
ut, err = sdk.CreateToken(mgxsdk.Login{Identity: user.Credentials.Identity, Secret: user.Credentials.Secret, DomainID: domain.ID})
ut, err = sdk.CreateToken(mgxsdk.Login{Identity: user.Credentials.Identity, Secret: user.Credentials.Secret})
if err != nil {
logErrorCmd(*cmd, err)
return
+5 -6
View File
@@ -46,7 +46,7 @@ var cmdThings = []cobra.Command{
"Usage:\n" +
"\tmagistrala-cli things get all $DOMAINID $USERTOKEN - lists all things\n" +
"\tmagistrala-cli things get all $DOMAINID $USERTOKEN --offset=10 --limit=10 - lists all things with offset and limit\n" +
"\tmagistrala-cli things get <thing_id> $USERTOKEN - shows thing with provided <thing_id>\n",
"\tmagistrala-cli things get <thing_id> $DOMAINID $USERTOKEN - shows thing with provided <thing_id>\n",
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 3 {
logUsageCmd(*cmd, cmd.Use)
@@ -64,7 +64,7 @@ var cmdThings = []cobra.Command{
Metadata: metadata,
}
if args[0] == all {
l, err := sdk.Things(pageMetadata, args[1])
l, err := sdk.Things(pageMetadata, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
@@ -329,11 +329,10 @@ var cmdThings = []cobra.Command{
return
}
pm := mgxsdk.PageMetadata{
Offset: Offset,
Limit: Limit,
DomainID: args[1],
Offset: Offset,
Limit: Limit,
}
ul, err := sdk.ListThingUsers(args[0], pm, args[2])
ul, err := sdk.ListThingUsers(args[0], pm, args[1], args[2])
if err != nil {
logErrorCmd(*cmd, err)
return
+2 -2
View File
@@ -220,7 +220,7 @@ func TestGetThingsCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("Things", mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
sdkCall := sdkMock.On("Things", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
sdkCall1 := sdkMock.On("Thing", mock.Anything, mock.Anything, mock.Anything).Return(tc.thing, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...)
@@ -779,7 +779,7 @@ func TestUsersThingCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("ListThingUsers", mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
sdkCall := sdkMock.On("ListThingUsers", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr)
out := executeCommand(t, rootCmd, append([]string{usrCmd}, tc.args...)...)
switch tc.logType {
+5 -12
View File
@@ -91,13 +91,13 @@ var cmdUsers = []cobra.Command{
},
},
{
Use: "token <username> <password> [<domainID>]",
Use: "token <username> <password>",
Short: "Get token",
Long: "Generate new token from username and password\n" +
"For example:\n" +
"\tmagistrala-cli users token user@example.com 12345678\n",
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 3 && len(args) != 2 {
if len(args) != 2 {
logUsageCmd(*cmd, cmd.Use)
return
}
@@ -106,9 +106,6 @@ var cmdUsers = []cobra.Command{
Identity: args[0],
Secret: args[1],
}
if len(args) == 3 {
lg.DomainID = args[2]
}
token, err := sdk.CreateToken(lg)
if err != nil {
@@ -120,22 +117,18 @@ var cmdUsers = []cobra.Command{
},
},
{
Use: "refreshtoken <token> [<domainID>]",
Use: "refreshtoken <token>",
Short: "Get token",
Long: "Generate new token from refresh token\n" +
"For example:\n" +
"\tmagistrala-cli users refreshtoken <refresh_token>\n",
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 2 && len(args) != 1 {
if len(args) != 1 {
logUsageCmd(*cmd, cmd.Use)
return
}
lg := mgxsdk.Login{}
if len(args) == 2 {
lg.DomainID = args[1]
}
token, err := sdk.RefreshToken(lg, args[0])
token, err := sdk.RefreshToken(args[0])
if err != nil {
logErrorCmd(*cmd, err)
return
+13 -58
View File
@@ -279,7 +279,6 @@ func TestIssueTokenCmd(t *testing.T) {
rootCmd := setFlags(usersCmd)
var tkn mgsdk.Token
domainID := testsutil.GenerateUUID(t)
invalidPassword := ""
token := mgsdk.Token{
@@ -296,7 +295,7 @@ func TestIssueTokenCmd(t *testing.T) {
logType outputLog
}{
{
desc: "issue token successfully without domain id",
desc: "issue token successfully",
args: []string{
user.Credentials.Identity,
user.Credentials.Secret,
@@ -305,17 +304,6 @@ func TestIssueTokenCmd(t *testing.T) {
logType: entityLog,
token: token,
},
{
desc: "issue token successfully with domain id",
args: []string{
user.Credentials.Identity,
user.Credentials.Secret,
domainID,
},
sdkerr: nil,
logType: entityLog,
token: token,
},
{
desc: "issue token with failed authentication",
args: []string{
@@ -331,6 +319,8 @@ func TestIssueTokenCmd(t *testing.T) {
desc: "issue token with invalid args",
args: []string{
user.Credentials.Identity,
user.Credentials.Secret,
extraArg,
},
logType: usageLog,
},
@@ -338,22 +328,12 @@ func TestIssueTokenCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("CreateToken", mock.Anything).Return(tc.token, tc.sdkerr)
switch len(tc.args) {
case 2:
lg := mgsdk.Login{
Identity: tc.args[0],
Secret: tc.args[1],
}
sdkCall = sdkMock.On("CreateToken", lg).Return(tc.token, tc.sdkerr)
case 3:
lg := mgsdk.Login{
Identity: tc.args[0],
Secret: tc.args[1],
DomainID: tc.args[2],
}
sdkCall = sdkMock.On("CreateToken", lg).Return(tc.token, tc.sdkerr)
lg := mgsdk.Login{
Identity: tc.args[0],
Secret: tc.args[1],
}
sdkCall := sdkMock.On("CreateToken", lg).Return(tc.token, tc.sdkerr)
out := executeCommand(t, rootCmd, append([]string{tokCmd}, tc.args...)...)
switch tc.logType {
@@ -379,8 +359,6 @@ func TestRefreshIssueTokenCmd(t *testing.T) {
rootCmd := setFlags(usersCmd)
var tkn mgsdk.Token
domainID := testsutil.GenerateUUID(t)
invalidIdentity := "invalidIdentity"
token := mgsdk.Token{
AccessToken: testsutil.GenerateUUID(t),
@@ -398,17 +376,7 @@ func TestRefreshIssueTokenCmd(t *testing.T) {
{
desc: "issue refresh token successfully without domain id",
args: []string{
user.Credentials.Identity,
},
sdkerr: nil,
logType: entityLog,
token: token,
},
{
desc: "issue refresh token successfully with domain id",
args: []string{
user.Credentials.Identity,
domainID,
"token",
},
sdkerr: nil,
logType: entityLog,
@@ -417,8 +385,7 @@ func TestRefreshIssueTokenCmd(t *testing.T) {
{
desc: "issue refresh token with invalid args",
args: []string{
user.Credentials.Identity,
domainID,
"token",
extraArg,
},
logType: usageLog,
@@ -426,7 +393,7 @@ func TestRefreshIssueTokenCmd(t *testing.T) {
{
desc: "issue refresh token with invalid identity",
args: []string{
invalidIdentity,
"invalidToken",
},
sdkerr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden).Error()),
@@ -437,20 +404,8 @@ func TestRefreshIssueTokenCmd(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
sdkCall := sdkMock.On("RefreshToken", mock.Anything, mock.Anything).Return(tc.token, tc.sdkerr)
switch len(tc.args) {
case 1:
lg := mgsdk.Login{
Identity: tc.args[0],
}
sdkCall = sdkMock.On("RefreshToken", lg).Return(tc.token, tc.sdkerr)
case 2:
lg := mgsdk.Login{
Identity: tc.args[0],
DomainID: tc.args[1],
}
sdkCall = sdkMock.On("RefreshToken", lg).Return(tc.token, tc.sdkerr)
}
sdkCall := sdkMock.On("RefreshToken", mock.Anything).Return(tc.token, tc.sdkerr)
out := executeCommand(t, rootCmd, append([]string{refTokCmd}, tc.args...)...)
switch tc.logType {
+1 -1
View File
@@ -339,7 +339,7 @@ func createAdmin(ctx context.Context, c config, crepo clientspg.Repository, hsr
if _, err = crepo.Save(ctx, client); err != nil {
return "", err
}
if _, err = svc.IssueToken(ctx, c.AdminEmail, c.AdminPassword, ""); err != nil {
if _, err = svc.IssueToken(ctx, c.AdminEmail, c.AdminPassword); err != nil {
return "", err
}
return client.ID, nil
+9 -32
View File
@@ -16,7 +16,7 @@ type sessionKeyType string
const SessionKey = sessionKeyType("session")
func AuthenticateMiddleware(authn mgauthn.Authentication) func(http.Handler) http.Handler {
func AuthenticateMiddleware(authn mgauthn.Authentication, domainCheck bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := apiutil.ExtractBearerToken(r)
@@ -31,37 +31,14 @@ func AuthenticateMiddleware(authn mgauthn.Authentication) func(http.Handler) htt
return
}
ctx := context.WithValue(r.Context(), SessionKey, resp)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
func AuthenticateMiddlewareDomain(authn mgauthn.Authentication) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := apiutil.ExtractBearerToken(r)
if token == "" {
EncodeError(r.Context(), apiutil.ErrBearerToken, w)
return
}
domain := chi.URLParam(r, "domainID")
if domain == "" {
EncodeError(r.Context(), apiutil.ErrMissingDomainID, w)
return
}
resp, err := authn.Authenticate(r.Context(), token)
if err != nil {
EncodeError(r.Context(), err, w)
return
}
if domain != resp.DomainID {
resp = mgauthn.Session{}
EncodeError(r.Context(), apiutil.ErrValidation, w)
return
if domainCheck {
domain := chi.URLParam(r, "domainID")
if domain == "" {
EncodeError(r.Context(), apiutil.ErrMissingDomainID, w)
return
}
resp.DomainID = domain
resp.DomainUserID = domain + "_" + resp.UserID
}
ctx := context.WithValue(r.Context(), SessionKey, resp)
+221 -199
View File
@@ -58,7 +58,7 @@ func TestCreateGroupEndpoint(t *testing.T) {
{
desc: "successfully with groups kind",
kind: policies.NewGroupKind,
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: createGroupReq{
Group: groups.Group{
Name: valid,
@@ -72,7 +72,7 @@ func TestCreateGroupEndpoint(t *testing.T) {
{
desc: "successfully with channels kind",
kind: policies.NewChannelKind,
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: createGroupReq{
Group: groups.Group{
Name: valid,
@@ -98,7 +98,7 @@ func TestCreateGroupEndpoint(t *testing.T) {
{
desc: "unsuccessfully with invalid request",
kind: policies.NewGroupKind,
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: createGroupReq{
Group: groups.Group{},
},
@@ -108,7 +108,7 @@ func TestCreateGroupEndpoint(t *testing.T) {
{
desc: "unsuccessfully with repo error",
kind: policies.NewGroupKind,
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: createGroupReq{
Group: groups.Group{
Name: valid,
@@ -122,22 +122,24 @@ func TestCreateGroupEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("CreateGroup", ctx, tc.session, tc.kind, tc.req.Group).Return(tc.svcResp, tc.svcErr)
resp, err := CreateGroupEndpoint(svc, tc.kind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(createGroupRes)
switch err {
case nil:
assert.Equal(t, response.Code(), http.StatusCreated)
assert.Equal(t, response.Headers()["Location"], fmt.Sprintf("/groups/%s", response.ID))
default:
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
}
assert.False(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("CreateGroup", ctx, tc.session, tc.kind, tc.req.Group).Return(tc.svcResp, tc.svcErr)
resp, err := CreateGroupEndpoint(svc, tc.kind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(createGroupRes)
switch err {
case nil:
assert.Equal(t, response.Code(), http.StatusCreated)
assert.Equal(t, response.Headers()["Location"], fmt.Sprintf("/groups/%s", response.ID))
default:
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
}
assert.False(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -154,7 +156,7 @@ func TestViewGroupEndpoint(t *testing.T) {
}{
{
desc: "successfully",
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: groupReq{
id: testsutil.GenerateUUID(t),
},
@@ -175,7 +177,7 @@ func TestViewGroupEndpoint(t *testing.T) {
},
{
desc: "unsuccessfully with invalid request",
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: groupReq{},
svcResp: groups.Group{},
svcErr: nil,
@@ -184,7 +186,7 @@ func TestViewGroupEndpoint(t *testing.T) {
},
{
desc: "unsuccessfully with repo error",
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: groupReq{
id: testsutil.GenerateUUID(t),
},
@@ -196,16 +198,18 @@ func TestViewGroupEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("ViewGroup", ctx, tc.session, tc.req.id).Return(tc.svcResp, tc.svcErr)
resp, err := ViewGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(viewGroupRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("ViewGroup", ctx, tc.session, tc.req.id).Return(tc.svcResp, tc.svcErr)
resp, err := ViewGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(viewGroupRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -225,7 +229,7 @@ func TestViewGroupPermsEndpoint(t *testing.T) {
req: groupPermsReq{
id: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: []string{
valid,
},
@@ -244,13 +248,13 @@ func TestViewGroupPermsEndpoint(t *testing.T) {
{
desc: "unsuccessfully with invalid request",
req: groupPermsReq{},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
resp: viewGroupPermsRes{},
err: apiutil.ErrValidation,
},
{
desc: "unsuccessfully with repo error",
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: groupPermsReq{
id: testsutil.GenerateUUID(t),
},
@@ -262,16 +266,18 @@ func TestViewGroupPermsEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("ViewGroupPerms", ctx, tc.session, tc.req.id).Return(tc.svcResp, tc.svcErr)
resp, err := ViewGroupPermsEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(viewGroupPermsRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("ViewGroupPerms", ctx, tc.session, tc.req.id).Return(tc.svcResp, tc.svcErr)
resp, err := ViewGroupPermsEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(viewGroupPermsRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -291,7 +297,7 @@ func TestEnableGroupEndpoint(t *testing.T) {
req: changeGroupStatusReq{
id: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: validGroupResp,
svcErr: nil,
resp: changeStatusRes{Group: validGroupResp},
@@ -307,7 +313,7 @@ func TestEnableGroupEndpoint(t *testing.T) {
},
{
desc: "unsuccessfully with invalid request",
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: changeGroupStatusReq{},
resp: changeStatusRes{},
err: apiutil.ErrValidation,
@@ -317,7 +323,7 @@ func TestEnableGroupEndpoint(t *testing.T) {
req: changeGroupStatusReq{
id: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.Group{},
svcErr: svcerr.ErrAuthorization,
resp: changeStatusRes{},
@@ -326,16 +332,18 @@ func TestEnableGroupEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("EnableGroup", ctx, tc.session, tc.req.id).Return(tc.svcResp, tc.svcErr)
resp, err := EnableGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(changeStatusRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("EnableGroup", ctx, tc.session, tc.req.id).Return(tc.svcResp, tc.svcErr)
resp, err := EnableGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(changeStatusRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -355,7 +363,7 @@ func TestDisableGroupEndpoint(t *testing.T) {
req: changeGroupStatusReq{
id: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: validGroupResp,
svcErr: nil,
resp: changeStatusRes{Group: validGroupResp},
@@ -371,7 +379,7 @@ func TestDisableGroupEndpoint(t *testing.T) {
},
{
desc: "unsuccessfully with invalid request",
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: changeGroupStatusReq{},
resp: changeStatusRes{},
err: apiutil.ErrValidation,
@@ -381,7 +389,7 @@ func TestDisableGroupEndpoint(t *testing.T) {
req: changeGroupStatusReq{
id: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.Group{},
svcErr: svcerr.ErrAuthorization,
resp: changeStatusRes{},
@@ -390,16 +398,18 @@ func TestDisableGroupEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("DisableGroup", ctx, tc.session, tc.req.id).Return(tc.svcResp, tc.svcErr)
resp, err := DisableGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(changeStatusRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("DisableGroup", ctx, tc.session, tc.req.id).Return(tc.svcResp, tc.svcErr)
resp, err := DisableGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(changeStatusRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -418,7 +428,7 @@ func TestDeleteGroupEndpoint(t *testing.T) {
req: groupReq{
id: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcErr: nil,
resp: deleteGroupRes{deleted: true},
err: nil,
@@ -434,7 +444,7 @@ func TestDeleteGroupEndpoint(t *testing.T) {
{
desc: "unsuccessfully with invalid request",
req: groupReq{},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
resp: deleteGroupRes{},
err: apiutil.ErrValidation,
},
@@ -443,7 +453,7 @@ func TestDeleteGroupEndpoint(t *testing.T) {
req: groupReq{
id: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcErr: svcerr.ErrAuthorization,
resp: deleteGroupRes{},
err: svcerr.ErrAuthorization,
@@ -451,21 +461,23 @@ func TestDeleteGroupEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("DeleteGroup", ctx, tc.session, tc.req.id).Return(tc.svcErr)
resp, err := DeleteGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(deleteGroupRes)
switch err {
case nil:
assert.Equal(t, response.Code(), http.StatusNoContent)
default:
assert.Equal(t, response.Code(), http.StatusBadRequest)
}
assert.Empty(t, response.Headers())
assert.True(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
svcCall := svc.On("DeleteGroup", ctx, tc.session, tc.req.id).Return(tc.svcErr)
resp, err := DeleteGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(deleteGroupRes)
switch err {
case nil:
assert.Equal(t, response.Code(), http.StatusNoContent)
default:
assert.Equal(t, response.Code(), http.StatusBadRequest)
}
assert.Empty(t, response.Headers())
assert.True(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -486,7 +498,7 @@ func TestUpdateGroupEndpoint(t *testing.T) {
id: testsutil.GenerateUUID(t),
Name: valid,
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: validGroupResp,
svcErr: nil,
resp: updateGroupRes{Group: validGroupResp},
@@ -504,7 +516,7 @@ func TestUpdateGroupEndpoint(t *testing.T) {
{
desc: "unsuccessfully with invalid request",
req: updateGroupReq{},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
resp: updateGroupRes{},
err: apiutil.ErrValidation,
},
@@ -514,7 +526,7 @@ func TestUpdateGroupEndpoint(t *testing.T) {
id: testsutil.GenerateUUID(t),
Name: valid,
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.Group{},
svcErr: svcerr.ErrAuthorization,
resp: updateGroupRes{},
@@ -523,22 +535,24 @@ func TestUpdateGroupEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
group := groups.Group{
ID: tc.req.id,
Name: tc.req.Name,
Description: tc.req.Description,
Metadata: tc.req.Metadata,
}
svcCall := svc.On("UpdateGroup", ctx, tc.session, group).Return(tc.svcResp, tc.svcErr)
resp, err := UpdateGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(updateGroupRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
group := groups.Group{
ID: tc.req.id,
Name: tc.req.Name,
Description: tc.req.Description,
Metadata: tc.req.Metadata,
}
svcCall := svc.On("UpdateGroup", ctx, tc.session, group).Return(tc.svcResp, tc.svcErr)
resp, err := UpdateGroupEndpoint(svc)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(updateGroupRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -601,7 +615,7 @@ func TestListGroupsEndpoint(t *testing.T) {
memberKind: policies.ThingsKind,
memberID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.Page{
Groups: []groups.Group{validGroupResp},
},
@@ -626,7 +640,7 @@ func TestListGroupsEndpoint(t *testing.T) {
memberKind: policies.ThingsKind,
memberID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.Page{
Groups: []groups.Group{validGroupResp},
},
@@ -653,7 +667,7 @@ func TestListGroupsEndpoint(t *testing.T) {
memberKind: policies.ThingsKind,
memberID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.Page{
Groups: []groups.Group{validGroupResp, childGroup},
},
@@ -682,7 +696,7 @@ func TestListGroupsEndpoint(t *testing.T) {
memberKind: policies.UsersKind,
memberID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.Page{
Groups: []groups.Group{validGroupResp, childGroup},
},
@@ -711,7 +725,7 @@ func TestListGroupsEndpoint(t *testing.T) {
memberKind: policies.UsersKind,
memberID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.Page{
Groups: []groups.Group{parentGroup, validGroupResp},
},
@@ -728,7 +742,7 @@ func TestListGroupsEndpoint(t *testing.T) {
{
desc: "unsuccessfully with invalid request",
memberKind: policies.ThingsKind,
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
req: listGroupsReq{},
resp: groupPageRes{},
err: apiutil.ErrValidation,
@@ -745,7 +759,7 @@ func TestListGroupsEndpoint(t *testing.T) {
memberKind: policies.ThingsKind,
memberID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.Page{},
svcErr: svcerr.ErrAuthorization,
resp: groupPageRes{},
@@ -777,26 +791,28 @@ func TestListGroupsEndpoint(t *testing.T) {
memberKind: "",
memberID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
resp: groupPageRes{},
err: apiutil.ErrValidation,
},
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
if tc.memberKind != "" {
tc.req.memberKind = tc.memberKind
}
svcCall := svc.On("ListGroups", ctx, tc.session, tc.req.memberKind, tc.req.memberID, tc.req.Page).Return(tc.svcResp, tc.svcErr)
resp, err := ListGroupsEndpoint(svc, mock.Anything, tc.memberKind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(groupPageRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
if tc.memberKind != "" {
tc.req.memberKind = tc.memberKind
}
svcCall := svc.On("ListGroups", ctx, tc.session, tc.req.memberKind, tc.req.memberID, tc.req.Page).Return(tc.svcResp, tc.svcErr)
resp, err := ListGroupsEndpoint(svc, mock.Anything, tc.memberKind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(groupPageRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -819,7 +835,7 @@ func TestListMembersEndpoint(t *testing.T) {
memberKind: policies.ThingsKind,
groupID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.MembersPage{
Members: []groups.Member{
{
@@ -845,7 +861,7 @@ func TestListMembersEndpoint(t *testing.T) {
memberKind: policies.ThingsKind,
groupID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.MembersPage{
Members: []groups.Member{
{
@@ -869,7 +885,7 @@ func TestListMembersEndpoint(t *testing.T) {
desc: "unsuccessfully with invalid request",
memberKind: policies.ThingsKind,
req: listMembersReq{},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
resp: listMembersRes{},
err: apiutil.ErrValidation,
},
@@ -880,7 +896,7 @@ func TestListMembersEndpoint(t *testing.T) {
memberKind: policies.ThingsKind,
groupID: testsutil.GenerateUUID(t),
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcResp: groups.MembersPage{},
svcErr: svcerr.ErrAuthorization,
resp: listMembersRes{},
@@ -899,19 +915,21 @@ func TestListMembersEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
if tc.memberKind != "" {
tc.req.memberKind = tc.memberKind
}
svcCall := svc.On("ListMembers", ctx, tc.session, tc.req.groupID, tc.req.permission, tc.req.memberKind).Return(tc.svcResp, tc.svcErr)
resp, err := ListMembersEndpoint(svc, tc.memberKind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(listMembersRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
if tc.memberKind != "" {
tc.req.memberKind = tc.memberKind
}
svcCall := svc.On("ListMembers", ctx, tc.session, tc.req.groupID, tc.req.permission, tc.req.memberKind).Return(tc.svcResp, tc.svcErr)
resp, err := ListMembersEndpoint(svc, tc.memberKind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(listMembersRes)
assert.Equal(t, response.Code(), http.StatusOK)
assert.Empty(t, response.Headers())
assert.False(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -939,7 +957,7 @@ func TestAssignMembersEndpoint(t *testing.T) {
testsutil.GenerateUUID(t),
},
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcErr: nil,
resp: assignRes{assigned: true},
err: nil,
@@ -955,7 +973,7 @@ func TestAssignMembersEndpoint(t *testing.T) {
testsutil.GenerateUUID(t),
},
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcErr: nil,
resp: assignRes{assigned: true},
err: nil,
@@ -971,7 +989,7 @@ func TestAssignMembersEndpoint(t *testing.T) {
testsutil.GenerateUUID(t),
},
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcErr: nil,
resp: assignRes{assigned: true},
err: nil,
@@ -981,7 +999,7 @@ func TestAssignMembersEndpoint(t *testing.T) {
relation: policies.ContributorRelation,
memberKind: policies.ThingsKind,
req: assignReq{},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
resp: assignRes{},
err: apiutil.ErrValidation,
},
@@ -997,7 +1015,7 @@ func TestAssignMembersEndpoint(t *testing.T) {
testsutil.GenerateUUID(t),
},
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcErr: svcerr.ErrAuthorization,
resp: assignRes{},
err: svcerr.ErrAuthorization,
@@ -1020,27 +1038,29 @@ func TestAssignMembersEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
if tc.memberKind != "" {
tc.req.MemberKind = tc.memberKind
}
if tc.relation != "" {
tc.req.Relation = tc.relation
}
svcCall := svc.On("Assign", ctx, tc.session, tc.req.groupID, tc.req.Relation, tc.req.MemberKind, tc.req.Members).Return(tc.svcErr)
resp, err := AssignMembersEndpoint(svc, tc.relation, tc.memberKind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(assignRes)
switch err {
case nil:
assert.Equal(t, response.Code(), http.StatusCreated)
default:
assert.Equal(t, response.Code(), http.StatusBadRequest)
}
assert.Empty(t, response.Headers())
assert.True(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
if tc.memberKind != "" {
tc.req.MemberKind = tc.memberKind
}
if tc.relation != "" {
tc.req.Relation = tc.relation
}
svcCall := svc.On("Assign", ctx, tc.session, tc.req.groupID, tc.req.Relation, tc.req.MemberKind, tc.req.Members).Return(tc.svcErr)
resp, err := AssignMembersEndpoint(svc, tc.relation, tc.memberKind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(assignRes)
switch err {
case nil:
assert.Equal(t, response.Code(), http.StatusCreated)
default:
assert.Equal(t, response.Code(), http.StatusBadRequest)
}
assert.Empty(t, response.Headers())
assert.True(t, response.Empty())
svcCall.Unset()
})
}
}
@@ -1068,7 +1088,7 @@ func TestUnassignMembersEndpoint(t *testing.T) {
testsutil.GenerateUUID(t),
},
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcErr: nil,
resp: unassignRes{unassigned: true},
err: nil,
@@ -1084,7 +1104,7 @@ func TestUnassignMembersEndpoint(t *testing.T) {
testsutil.GenerateUUID(t),
},
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcErr: nil,
resp: unassignRes{unassigned: true},
err: nil,
@@ -1101,7 +1121,7 @@ func TestUnassignMembersEndpoint(t *testing.T) {
},
},
svcErr: nil,
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
resp: unassignRes{unassigned: true},
err: nil,
},
@@ -1110,7 +1130,7 @@ func TestUnassignMembersEndpoint(t *testing.T) {
relation: policies.ContributorRelation,
memberKind: policies.ThingsKind,
req: unassignReq{},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
resp: unassignRes{},
err: apiutil.ErrValidation,
},
@@ -1126,7 +1146,7 @@ func TestUnassignMembersEndpoint(t *testing.T) {
testsutil.GenerateUUID(t),
},
},
session: mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
session: mgauthn.Session{DomainUserID: validID + "_" + validID, UserID: validID, DomainID: validID},
svcErr: svcerr.ErrAuthorization,
resp: unassignRes{},
err: svcerr.ErrAuthorization,
@@ -1149,26 +1169,28 @@ func TestUnassignMembersEndpoint(t *testing.T) {
}
for _, tc := range cases {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
if tc.memberKind != "" {
tc.req.MemberKind = tc.memberKind
}
if tc.relation != "" {
tc.req.Relation = tc.relation
}
svcCall := svc.On("Unassign", ctx, tc.session, tc.req.groupID, tc.req.Relation, tc.req.MemberKind, tc.req.Members).Return(tc.svcErr)
resp, err := UnassignMembersEndpoint(svc, tc.relation, tc.memberKind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(unassignRes)
switch err {
case nil:
assert.Equal(t, response.Code(), http.StatusCreated)
default:
assert.Equal(t, response.Code(), http.StatusBadRequest)
}
assert.Empty(t, response.Headers())
assert.True(t, response.Empty())
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.WithValue(context.Background(), api.SessionKey, tc.session)
if tc.memberKind != "" {
tc.req.MemberKind = tc.memberKind
}
if tc.relation != "" {
tc.req.Relation = tc.relation
}
svcCall := svc.On("Unassign", ctx, tc.session, tc.req.groupID, tc.req.Relation, tc.req.MemberKind, tc.req.Members).Return(tc.svcErr)
resp, err := UnassignMembersEndpoint(svc, tc.relation, tc.memberKind)(ctx, tc.req)
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err))
response := resp.(unassignRes)
switch err {
case nil:
assert.Equal(t, response.Code(), http.StatusCreated)
default:
assert.Equal(t, response.Code(), http.StatusBadRequest)
}
assert.Empty(t, response.Headers())
assert.True(t, response.Empty())
svcCall.Unset()
})
}
}
+2
View File
@@ -49,6 +49,7 @@ func ViewGroupEndpoint(svc groups.Service) endpoint.Endpoint {
if !ok {
return viewGroupRes{}, svcerr.ErrAuthorization
}
group, err := svc.ViewGroup(ctx, session, req.id)
if err != nil {
return viewGroupRes{}, err
@@ -278,6 +279,7 @@ func DeleteGroupEndpoint(svc groups.Service) endpoint.Endpoint {
if !ok {
return deleteGroupRes{}, svcerr.ErrAuthorization
}
if err := svc.DeleteGroup(ctx, session, req.id); err != nil {
return deleteGroupRes{}, err
}
+6 -5
View File
@@ -32,7 +32,7 @@ func sendInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
invitation := invitations.Invitation{
UserID: req.UserID,
DomainID: req.DomainID,
DomainID: session.DomainID,
Relation: req.Relation,
Resend: req.Resend,
}
@@ -53,11 +53,11 @@ func viewInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
req.domainID = session.DomainID
invitation, err := svc.ViewInvitation(ctx, session, req.userID, req.domainID)
if err != nil {
@@ -82,6 +82,7 @@ func listInvitationsEndpoint(svc invitations.Service) endpoint.Endpoint {
return nil, svcerr.ErrAuthorization
}
req.Page.DomainID = session.DomainID
page, err := svc.ListInvitations(ctx, session, req.Page)
if err != nil {
return nil, err
@@ -104,7 +105,7 @@ func acceptInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
if !ok {
return nil, svcerr.ErrAuthorization
}
req.domainID = session.DomainID
if err := svc.AcceptInvitation(ctx, session, req.domainID); err != nil {
return nil, err
}
@@ -124,7 +125,7 @@ func rejectInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
if !ok {
return nil, svcerr.ErrAuthorization
}
req.domainID = session.DomainID
if err := svc.RejectInvitation(ctx, session, req.domainID); err != nil {
return nil, err
}
@@ -144,7 +145,7 @@ func deleteInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
if !ok {
return nil, svcerr.ErrAuthorization
}
req.domainID = session.DomainID
if err := svc.DeleteInvitation(ctx, session, req.userID, req.domainID); err != nil {
return nil, err
}
+35 -13
View File
@@ -84,7 +84,7 @@ func TestSendInvitation(t *testing.T) {
domainID: domainID,
token: validToken,
data: fmt.Sprintf(`{"user_id": "%s", "relation": "%s"}`, validID, "domain"),
authnRes: mgauthn.Session{UserID: validID, DomainID: validID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
status: http.StatusCreated,
contentType: validContenType,
svcErr: nil,
@@ -112,6 +112,7 @@ func TestSendInvitation(t *testing.T) {
domainID: domainID,
token: validToken,
data: fmt.Sprintf(`{"user_id": "%s", "relation": "%s"}`, validID, "domain"),
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
status: http.StatusUnsupportedMediaType,
contentType: "text/plain",
svcErr: nil,
@@ -121,6 +122,7 @@ func TestSendInvitation(t *testing.T) {
domainID: domainID,
token: validToken,
data: `data`,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
@@ -130,6 +132,7 @@ func TestSendInvitation(t *testing.T) {
domainID: domainID,
token: validToken,
data: fmt.Sprintf(`{"user_id": "%s", "relation": "%s"}`, validID, "domain"),
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
status: http.StatusForbidden,
contentType: validContenType,
svcErr: svcerr.ErrAuthorization,
@@ -175,6 +178,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "valid request",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
status: http.StatusOK,
contentType: validContenType,
@@ -191,6 +195,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with offset",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "offset=1",
status: http.StatusOK,
@@ -209,6 +214,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with limit",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "limit=1",
status: http.StatusOK,
@@ -227,6 +233,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with user_id",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: fmt.Sprintf("user_id=%s", validID),
status: http.StatusOK,
@@ -236,6 +243,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with duplicate user_id",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "user_id=1&user_id=2",
status: http.StatusBadRequest,
@@ -245,6 +253,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with invited_by",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: fmt.Sprintf("invited_by=%s", validID),
status: http.StatusOK,
@@ -254,6 +263,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with duplicate invited_by",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "invited_by=1&invited_by=2",
status: http.StatusBadRequest,
@@ -263,6 +273,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with relation",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: fmt.Sprintf("relation=%s", "relation"),
status: http.StatusOK,
@@ -272,6 +283,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with duplicate relation",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "relation=1&relation=2",
status: http.StatusBadRequest,
@@ -289,6 +301,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with state",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "state=pending",
status: http.StatusOK,
@@ -316,6 +329,7 @@ func TestListInvitation(t *testing.T) {
{
desc: "with service error",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
status: http.StatusForbidden,
contentType: validContenType,
@@ -359,9 +373,10 @@ func TestViewInvitation(t *testing.T) {
}{
{
desc: "valid request",
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
userID: validID,
domainID: validID,
domainID: domainID,
status: http.StatusOK,
contentType: validContenType,
svcErr: nil,
@@ -370,25 +385,26 @@ func TestViewInvitation(t *testing.T) {
desc: "invalid token",
token: "",
userID: validID,
domainID: validID,
domainID: domainID,
status: http.StatusUnauthorized,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with service error",
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
userID: validID,
domainID: validID,
status: http.StatusForbidden,
domainID: domainID,
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: svcerr.ErrAuthorization,
svcErr: svcerr.ErrViewEntity,
},
{
desc: "with empty user_id",
token: validToken,
userID: "",
domainID: validID,
domainID: domainID,
status: http.StatusNotFound,
contentType: validContenType,
svcErr: nil,
@@ -407,7 +423,7 @@ func TestViewInvitation(t *testing.T) {
token: validToken,
userID: "",
domainID: "",
status: http.StatusNotFound,
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
@@ -451,9 +467,10 @@ func TestDeleteInvitation(t *testing.T) {
}{
{
desc: "valid request",
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
userID: validID,
domainID: validID,
domainID: domainID,
status: http.StatusNoContent,
contentType: validContenType,
svcErr: nil,
@@ -462,16 +479,17 @@ func TestDeleteInvitation(t *testing.T) {
desc: "invalid token",
token: "",
userID: validID,
domainID: validID,
domainID: domainID,
status: http.StatusUnauthorized,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with service error",
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
userID: validID,
domainID: validID,
domainID: domainID,
status: http.StatusForbidden,
contentType: validContenType,
svcErr: svcerr.ErrAuthorization,
@@ -480,7 +498,7 @@ func TestDeleteInvitation(t *testing.T) {
desc: "with empty user_id",
token: validToken,
userID: "",
domainID: validID,
domainID: domainID,
status: http.StatusNotFound,
contentType: validContenType,
svcErr: nil,
@@ -499,7 +517,7 @@ func TestDeleteInvitation(t *testing.T) {
token: validToken,
userID: "",
domainID: "",
status: http.StatusNotFound,
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
@@ -543,6 +561,7 @@ func TestAcceptInvitation(t *testing.T) {
{
desc: "valid request",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
status: http.StatusNoContent,
contentType: validContenType,
@@ -559,6 +578,7 @@ func TestAcceptInvitation(t *testing.T) {
{
desc: "with service error",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
status: http.StatusForbidden,
contentType: validContenType,
@@ -620,6 +640,7 @@ func TestRejectInvitation(t *testing.T) {
{
desc: "valid request",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
status: http.StatusNoContent,
contentType: validContenType,
@@ -636,6 +657,7 @@ func TestRejectInvitation(t *testing.T) {
{
desc: "unauthorized error",
domainID: domainID,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
status: http.StatusForbidden,
contentType: validContenType,
+1 -31
View File
@@ -11,23 +11,16 @@ import (
const maxLimitSize = 100
type sendInvitationReq struct {
token string
DomainID string `json:"domain_id,omitempty"`
domainID string
UserID string `json:"user_id,omitempty"`
Relation string `json:"relation,omitempty"`
Resend bool `json:"resend,omitempty"`
}
func (req *sendInvitationReq) validate() error {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.UserID == "" {
return apiutil.ErrMissingID
}
if req.DomainID == "" {
return apiutil.ErrMissingDomainID
}
if err := invitations.CheckRelation(req.Relation); err != nil {
return err
}
@@ -36,18 +29,10 @@ func (req *sendInvitationReq) validate() error {
}
type listInvitationsReq struct {
token string
invitations.Page
}
func (req *listInvitationsReq) validate() error {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.Page.DomainID == "" {
return apiutil.ErrMissingDomainID
}
if req.Page.Limit > maxLimitSize || req.Page.Limit < 1 {
return apiutil.ErrLimitSize
}
@@ -56,37 +41,22 @@ func (req *listInvitationsReq) validate() error {
}
type acceptInvitationReq struct {
token string
domainID string
}
func (req *acceptInvitationReq) validate() error {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.domainID == "" {
return apiutil.ErrMissingDomainID
}
return nil
}
type invitationReq struct {
token string
userID string
domainID string
}
func (req *invitationReq) validate() error {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.userID == "" {
return apiutil.ErrMissingID
}
if req.domainID == "" {
return apiutil.ErrMissingDomainID
}
return nil
}
+4 -87
View File
@@ -24,53 +24,28 @@ func TestSendInvitationReqValidation(t *testing.T) {
{
desc: "valid request",
req: sendInvitationReq{
token: valid,
UserID: valid,
DomainID: valid,
domainID: valid,
Relation: policies.DomainRelation,
Resend: true,
},
err: nil,
},
{
desc: "empty token",
req: sendInvitationReq{
token: "",
UserID: valid,
DomainID: valid,
Relation: policies.DomainRelation,
Resend: true,
},
err: apiutil.ErrBearerToken,
},
{
desc: "empty user ID",
req: sendInvitationReq{
token: valid,
UserID: "",
DomainID: valid,
domainID: valid,
Relation: policies.DomainRelation,
Resend: true,
},
err: apiutil.ErrMissingID,
},
{
desc: "empty domain_id",
req: sendInvitationReq{
token: valid,
UserID: valid,
DomainID: "",
Relation: policies.DomainRelation,
Resend: true,
},
err: apiutil.ErrMissingDomainID,
},
{
desc: "missing relation",
req: sendInvitationReq{
token: valid,
UserID: valid,
DomainID: valid,
domainID: valid,
Relation: "",
Resend: true,
},
@@ -79,9 +54,8 @@ func TestSendInvitationReqValidation(t *testing.T) {
{
desc: "invalid relation",
req: sendInvitationReq{
token: valid,
UserID: valid,
DomainID: valid,
domainID: valid,
Relation: "invalid",
Resend: true,
},
@@ -106,7 +80,6 @@ func TestListInvitationsReq(t *testing.T) {
{
desc: "valid request",
req: listInvitationsReq{
token: valid,
Page: invitations.Page{
Limit: 1,
DomainID: valid,
@@ -114,29 +87,9 @@ func TestListInvitationsReq(t *testing.T) {
},
err: nil,
},
{
desc: "empty domainID",
req: listInvitationsReq{
token: valid,
Page: invitations.Page{Limit: 1},
},
err: apiutil.ErrMissingDomainID,
},
{
desc: "empty token",
req: listInvitationsReq{
token: "",
Page: invitations.Page{
Limit: 1,
DomainID: valid,
},
},
err: apiutil.ErrBearerToken,
},
{
desc: "invalid limit",
req: listInvitationsReq{
token: valid,
Page: invitations.Page{
Limit: 1000,
DomainID: valid,
@@ -163,26 +116,10 @@ func TestAcceptInvitationReq(t *testing.T) {
{
desc: "valid request",
req: acceptInvitationReq{
token: valid,
domainID: valid,
},
err: nil,
},
{
desc: "empty token",
req: acceptInvitationReq{
token: "",
},
err: apiutil.ErrBearerToken,
},
{
desc: "empty domain_id",
req: acceptInvitationReq{
token: valid,
domainID: "",
},
err: apiutil.ErrMissingDomainID,
},
}
for _, tc := range cases {
@@ -202,39 +139,19 @@ func TestInvitationReqValidation(t *testing.T) {
{
desc: "valid request",
req: invitationReq{
token: valid,
userID: valid,
domainID: valid,
},
err: nil,
},
{
desc: "empty token",
req: invitationReq{
token: "",
userID: valid,
domainID: valid,
},
err: apiutil.ErrBearerToken,
},
{
desc: "empty user ID",
req: invitationReq{
token: valid,
userID: "",
domainID: valid,
},
err: apiutil.ErrMissingID,
},
{
desc: "empty domain",
req: invitationReq{
token: valid,
userID: valid,
domainID: "",
},
err: apiutil.ErrMissingDomainID,
},
}
for _, tc := range cases {
+2 -9
View File
@@ -38,7 +38,7 @@ func MakeHandler(svc invitations.Service, logger *slog.Logger, authn mgauthn.Aut
mux := chi.NewRouter()
mux.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn))
r.Use(api.AuthenticateMiddleware(authn, true))
r.Route("/{domainID}/invitations", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
@@ -94,11 +94,9 @@ func decodeSendInvitationReq(_ context.Context, r *http.Request) (interface{}, e
}
var req sendInvitationReq
req.DomainID = chi.URLParam(r, "domainID")
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
}
req.token = apiutil.ExtractBearerToken(r)
return req, nil
}
@@ -134,14 +132,12 @@ func decodeListInvitationsReq(_ context.Context, r *http.Request) (interface{},
}
req := listInvitationsReq{
token: apiutil.ExtractBearerToken(r),
Page: invitations.Page{
Offset: offset,
Limit: limit,
InvitedBy: invitedBy,
UserID: userID,
Relation: relation,
DomainID: chi.URLParam(r, "domainID"),
State: state,
},
}
@@ -155,16 +151,13 @@ func decodeAcceptInvitationReq(_ context.Context, r *http.Request) (interface{},
}
return acceptInvitationReq{
token: apiutil.ExtractBearerToken(r),
domainID: chi.URLParam(r, "domainID"),
}, nil
}
func decodeInvitationReq(_ context.Context, r *http.Request) (interface{}, error) {
req := invitationReq{
token: apiutil.ExtractBearerToken(r),
userID: chi.URLParam(r, "user_id"),
domainID: chi.URLParam(r, "domainID"),
userID: chi.URLParam(r, "user_id"),
}
return req, nil
+1 -1
View File
@@ -35,7 +35,7 @@ func (svc *service) SendInvitation(ctx context.Context, session authn.Session, i
invitation.InvitedBy = session.UserID
joinToken, err := svc.token.Issue(ctx, &magistrala.IssueReq{UserId: session.UserID, DomainId: &invitation.DomainID, Type: uint32(auth.InvitationKey)})
joinToken, err := svc.token.Issue(ctx, &magistrala.IssueReq{UserId: session.UserID, Type: uint32(auth.InvitationKey)})
if err != nil {
return err
}
+2 -2
View File
@@ -113,8 +113,8 @@ func (sdk mgSDK) AddBootstrap(cfg BootstrapConfig, domainID, token string) (stri
return id, nil
}
func (sdk mgSDK) Bootstraps(pm PageMetadata, token string) (BootstrapPage, errors.SDKError) {
endpoint := fmt.Sprintf("%s/%s", pm.DomainID, configsEndpoint)
func (sdk mgSDK) Bootstraps(pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) {
endpoint := fmt.Sprintf("%s/%s", domainID, configsEndpoint)
url, err := sdk.withQueryParams(sdk.bootstrapURL, endpoint, pm)
if err != nil {
return BootstrapPage{}, errors.NewSDKError(err)
+60 -57
View File
@@ -175,14 +175,14 @@ func TestAddBootstrap(t *testing.T) {
err: nil,
},
{
desc: "add with invalid token",
domainID: domainID,
token: invalidToken,
cfg: sdkBootstrapConfig,
svcReq: bootstrapConfig,
svcRes: bootstrap.Config{},
svcErr: svcerr.ErrAuthentication,
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
desc: "add with invalid token",
domainID: domainID,
token: invalidToken,
cfg: sdkBootstrapConfig,
svcReq: bootstrapConfig,
svcRes: bootstrap.Config{},
authenticateErr: svcerr.ErrAuthentication,
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "add with config that cannot be marshalled",
@@ -242,7 +242,7 @@ func TestAddBootstrap(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := bsvc.On("Add", mock.Anything, tc.session, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr)
@@ -295,6 +295,7 @@ func TestListBootstraps(t *testing.T) {
cases := []struct {
desc string
domainID string
token string
session mgauthn.Session
pageMeta sdk.PageMetadata
@@ -305,12 +306,12 @@ func TestListBootstraps(t *testing.T) {
err errors.SDKError
}{
{
desc: "list successfully",
token: validToken,
desc: "list successfully",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcResp: bootstrap.ConfigsPage{
Total: 1,
@@ -326,25 +327,25 @@ func TestListBootstraps(t *testing.T) {
err: nil,
},
{
desc: "list with invalid token",
token: invalidToken,
desc: "list with invalid token",
domainID: domainID,
token: invalidToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcResp: bootstrap.ConfigsPage{},
svcErr: svcerr.ErrAuthentication,
response: sdk.BootstrapPage{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
svcResp: bootstrap.ConfigsPage{},
authenticateErr: svcerr.ErrAuthentication,
response: sdk.BootstrapPage{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "list with empty token",
token: "",
desc: "list with empty token",
domainID: domainID,
token: "",
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcResp: bootstrap.ConfigsPage{},
svcErr: nil,
@@ -352,12 +353,12 @@ func TestListBootstraps(t *testing.T) {
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "list with invalid query params",
token: validToken,
desc: "list with invalid query params",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 1,
Limit: 10,
DomainID: domainID,
Offset: 1,
Limit: 10,
Metadata: map[string]interface{}{
"test": make(chan int),
},
@@ -368,12 +369,12 @@ func TestListBootstraps(t *testing.T) {
err: errors.NewSDKError(errMarshalChan),
},
{
desc: "list with response that cannot be unmarshalled",
token: validToken,
desc: "list with response that cannot be unmarshalled",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcResp: bootstrap.ConfigsPage{
Total: 1,
@@ -388,11 +389,11 @@ func TestListBootstraps(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := bsvc.On("List", mock.Anything, tc.session, mock.Anything, tc.pageMeta.Offset, tc.pageMeta.Limit).Return(tc.svcResp, tc.svcErr)
resp, err := mgsdk.Bootstraps(tc.pageMeta, tc.token)
resp, err := mgsdk.Bootstraps(tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if err == nil {
@@ -493,7 +494,7 @@ func TestWhiteList(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := bsvc.On("ChangeState", mock.Anything, tc.session, tc.token, tc.thingID, tc.svcReq).Return(tc.svcErr)
@@ -614,7 +615,7 @@ func TestViewBootstrap(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := bsvc.On("View", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr)
@@ -777,7 +778,7 @@ func TestUpdateBootstrap(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticationErr)
svcCall := bsvc.On("Update", mock.Anything, tc.session, tc.svcReq).Return(tc.svcErr)
@@ -899,18 +900,20 @@ func TestUpdateBootstrapCerts(t *testing.T) {
},
}
for _, tc := range cases {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := bsvc.On("UpdateCert", mock.Anything, tc.session, tc.id, tc.clientCert, tc.clientKey, tc.caCert).Return(tc.svcResp, tc.svcErr)
resp, err := mgsdk.UpdateBootstrapCerts(tc.id, tc.clientCert, tc.clientKey, tc.caCert, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
if err == nil {
assert.Equal(t, tc.response, resp)
}
svcCall.Unset()
authCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := bsvc.On("UpdateCert", mock.Anything, tc.session, tc.id, tc.clientCert, tc.clientKey, tc.caCert).Return(tc.svcResp, tc.svcErr)
resp, err := mgsdk.UpdateBootstrapCerts(tc.id, tc.clientCert, tc.clientKey, tc.caCert, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
if err == nil {
assert.Equal(t, tc.response, resp)
}
svcCall.Unset()
authCall.Unset()
})
}
}
@@ -1002,7 +1005,7 @@ func TestUpdateBootstrapConnection(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := bsvc.On("UpdateConnections", mock.Anything, tc.session, tc.token, tc.id, tc.channels).Return(tc.svcErr)
@@ -1089,7 +1092,7 @@ func TestRemoveBootstrap(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := bsvc.On("Remove", mock.Anything, tc.session, tc.id).Return(tc.svcErr)
+6 -6
View File
@@ -52,8 +52,8 @@ func (sdk mgSDK) CreateChannel(c Channel, domainID, token string) (Channel, erro
return c, nil
}
func (sdk mgSDK) Channels(pm PageMetadata, token string) (ChannelsPage, errors.SDKError) {
endpoint := fmt.Sprintf("%s/%s", pm.DomainID, channelsEndpoint)
func (sdk mgSDK) Channels(pm PageMetadata, domainID, token string) (ChannelsPage, errors.SDKError) {
endpoint := fmt.Sprintf("%s/%s", domainID, channelsEndpoint)
url, err := sdk.withQueryParams(sdk.thingsURL, endpoint, pm)
if err != nil {
return ChannelsPage{}, errors.NewSDKError(err)
@@ -174,8 +174,8 @@ func (sdk mgSDK) RemoveUserFromChannel(channelID string, req UsersRelationReques
return sdkerr
}
func (sdk mgSDK) ListChannelUsers(channelID string, pm PageMetadata, token string) (UsersPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.usersURL, fmt.Sprintf("%s/%s/%s/%s", pm.DomainID, channelsEndpoint, channelID, usersEndpoint), pm)
func (sdk mgSDK) ListChannelUsers(channelID string, pm PageMetadata, domainID, token string) (UsersPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.usersURL, fmt.Sprintf("%s/%s/%s/%s", domainID, channelsEndpoint, channelID, usersEndpoint), pm)
if err != nil {
return UsersPage{}, errors.NewSDKError(err)
}
@@ -215,8 +215,8 @@ func (sdk mgSDK) RemoveUserGroupFromChannel(channelID string, req UserGroupsRequ
return sdkerr
}
func (sdk mgSDK) ListChannelUserGroups(channelID string, pm PageMetadata, token string) (GroupsPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.usersURL, fmt.Sprintf("%s/%s/%s/%s", pm.DomainID, channelsEndpoint, channelID, groupsEndpoint), pm)
func (sdk mgSDK) ListChannelUserGroups(channelID string, pm PageMetadata, domainID, token string) (GroupsPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.usersURL, fmt.Sprintf("%s/%s/%s/%s", domainID, channelsEndpoint, channelID, groupsEndpoint), pm)
if err != nil {
return GroupsPage{}, errors.NewSDKError(err)
}
+76 -78
View File
@@ -238,7 +238,7 @@ func TestCreateChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("CreateGroup", mock.Anything, tc.session, channelKind, tc.createGroupReq).Return(tc.svcRes, tc.svcErr)
@@ -511,14 +511,13 @@ func TestListChannels(t *testing.T) {
Limit: tc.limit,
Level: uint64(tc.level),
Metadata: tc.metadata,
DomainID: tc.domainID,
}
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ListGroups", mock.Anything, tc.session, policies.UsersKind, "", tc.groupsPageMeta).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.Channels(pm, tc.token)
resp, err := mgsdk.Channels(pm, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
@@ -623,7 +622,7 @@ func TestViewChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ViewGroup", mock.Anything, tc.session, tc.channelID).Return(tc.svcRes, tc.svcErr)
@@ -922,7 +921,7 @@ func TestUpdateChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("UpdateGroup", mock.Anything, tc.session, tc.updateGroupReq).Return(tc.svcRes, tc.svcErr)
@@ -1167,7 +1166,7 @@ func TestListChannelsByThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ListGroups", mock.Anything, tc.session, policies.ThingsKind, tc.thingID, tc.listGroupsReq).Return(tc.svcRes, tc.svcErr)
@@ -1275,7 +1274,7 @@ func TestEnableChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("EnableGroup", mock.Anything, tc.session, tc.channelID).Return(tc.svcRes, tc.svcErr)
@@ -1388,7 +1387,7 @@ func TestDisableChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("DisableGroup", mock.Anything, tc.session, tc.channelID).Return(tc.svcRes, tc.svcErr)
@@ -1468,7 +1467,7 @@ func TestDeleteChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("DeleteGroup", mock.Anything, tc.session, tc.channelID).Return(tc.svcErr)
@@ -1562,7 +1561,7 @@ func TestChannelPermissions(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ViewGroupPerms", mock.Anything, tc.session, tc.channelID).Return(tc.svcRes, tc.svcErr)
@@ -1687,7 +1686,7 @@ func TestAddUserToChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Assign", mock.Anything, tc.session, tc.channelID, tc.addUserReq.Relation, policies.UsersKind, tc.addUserReq.UserIDs).Return(tc.svcErr)
@@ -1799,7 +1798,7 @@ func TestRemoveUserFromChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Unassign", mock.Anything, tc.session, tc.channelID, tc.removeUserReq.Relation, policies.UsersKind, tc.removeUserReq.UserIDs).Return(tc.svcErr)
@@ -1909,7 +1908,7 @@ func TestAddUserGroupToChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Assign", mock.Anything, tc.session, tc.channelID, relation, policies.ChannelsKind, tc.addUserGroupReq.UserGroupIDs).Return(tc.svcErr)
@@ -2019,7 +2018,7 @@ func TestRemoveUserGroupFromChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Unassign", mock.Anything, tc.session, tc.channelID, relation, policies.ChannelsKind, tc.removeUserGroupReq.UserGroupIDs).Return(tc.svcErr)
@@ -2060,6 +2059,7 @@ func TestListChannelUserGroups(t *testing.T) {
cases := []struct {
desc string
token string
domainID string
session mgauthn.Session
channelID string
pageMeta sdk.PageMetadata
@@ -2072,11 +2072,10 @@ func TestListChannelUserGroups(t *testing.T) {
}{
{
desc: "list user groups successfully",
domainID: domainID,
token: validToken,
channelID: channel.ID,
pageMeta: sdk.PageMetadata{
DomainID: domainID,
},
pageMeta: sdk.PageMetadata{},
listGroupsReq: groups.Page{
PageMeta: groups.PageMeta{
Offset: 0,
@@ -2102,12 +2101,12 @@ func TestListChannelUserGroups(t *testing.T) {
},
{
desc: "list user groups with offset and limit",
domainID: domainID,
token: validToken,
channelID: channel.ID,
pageMeta: sdk.PageMetadata{
Offset: 6,
Limit: nGroups,
DomainID: domainID,
Offset: 6,
Limit: nGroups,
},
listGroupsReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -2134,11 +2133,10 @@ func TestListChannelUserGroups(t *testing.T) {
},
{
desc: "list user groups with invalid token",
domainID: domainID,
token: invalidToken,
channelID: channel.ID,
pageMeta: sdk.PageMetadata{
DomainID: domainID,
},
pageMeta: sdk.PageMetadata{},
listGroupsReq: groups.Page{
PageMeta: groups.PageMeta{
Offset: 0,
@@ -2154,11 +2152,10 @@ func TestListChannelUserGroups(t *testing.T) {
},
{
desc: "list user groups with empty token",
domainID: domainID,
token: "",
channelID: channel.ID,
pageMeta: sdk.PageMetadata{
DomainID: domainID,
},
pageMeta: sdk.PageMetadata{},
listGroupsReq: groups.Page{
PageMeta: groups.PageMeta{
Offset: 0,
@@ -2174,11 +2171,11 @@ func TestListChannelUserGroups(t *testing.T) {
},
{
desc: "list user groups with limit greater than max",
domainID: domainID,
token: validToken,
channelID: channel.ID,
pageMeta: sdk.PageMetadata{
Limit: 110,
DomainID: domainID,
Limit: 110,
},
listGroupsReq: groups.Page{},
svcRes: groups.Page{},
@@ -2188,6 +2185,7 @@ func TestListChannelUserGroups(t *testing.T) {
},
{
desc: "list user groups with invalid channel id",
domainID: domainID,
token: validToken,
channelID: wrongID,
pageMeta: sdk.PageMetadata{
@@ -2208,11 +2206,11 @@ func TestListChannelUserGroups(t *testing.T) {
},
{
desc: "list users groups with level exceeding max",
domainID: domainID,
token: validToken,
channelID: channel.ID,
pageMeta: sdk.PageMetadata{
Level: 10,
DomainID: domainID,
Level: 10,
},
listGroupsReq: groups.Page{},
svcRes: groups.Page{},
@@ -2240,11 +2238,10 @@ func TestListChannelUserGroups(t *testing.T) {
},
{
desc: "list user groups with service response that can't be unmarshalled",
domainID: domainID,
token: validToken,
channelID: channel.ID,
pageMeta: sdk.PageMetadata{
DomainID: domainID,
},
pageMeta: sdk.PageMetadata{},
listGroupsReq: groups.Page{
PageMeta: groups.PageMeta{
Offset: 0,
@@ -2272,11 +2269,11 @@ func TestListChannelUserGroups(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ListGroups", mock.Anything, tc.session, policies.ChannelsKind, tc.channelID, tc.listGroupsReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.ListChannelUserGroups(tc.channelID, tc.pageMeta, tc.token)
resp, err := mgsdk.ListChannelUserGroups(tc.channelID, tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
@@ -2380,7 +2377,7 @@ func TestConnect(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Assign", mock.Anything, tc.session, tc.connection.ChannelID, policies.GroupRelation, policies.ThingsKind, []string{tc.connection.ThingID}).Return(tc.svcErr)
@@ -2487,7 +2484,7 @@ func TestDisconnect(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Unassign", mock.Anything, tc.session, tc.disconnect.ChannelID, policies.GroupRelation, policies.ThingsKind, []string{tc.disconnect.ThingID}).Return(tc.svcErr)
@@ -2583,7 +2580,7 @@ func TestConnectThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Assign", mock.Anything, tc.session, tc.channelID, policies.GroupRelation, policies.ThingsKind, []string{tc.thingID}).Return(tc.svcErr)
@@ -2678,7 +2675,7 @@ func TestDisconnectThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Unassign", mock.Anything, tc.session, tc.channelID, policies.GroupRelation, policies.ThingsKind, []string{tc.thingID}).Return(tc.svcErr)
@@ -2712,6 +2709,7 @@ func TestListGroupChannels(t *testing.T) {
cases := []struct {
desc string
domainID string
token string
session mgauthn.Session
groupID string
@@ -2724,13 +2722,13 @@ func TestListGroupChannels(t *testing.T) {
err errors.SDKError
}{
{
desc: "list group channels successfully",
token: validToken,
groupID: group.ID,
desc: "list group channels successfully",
domainID: domainID,
token: validToken,
groupID: group.ID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -2756,13 +2754,13 @@ func TestListGroupChannels(t *testing.T) {
err: nil,
},
{
desc: "list group channels with invalid token",
token: invalidToken,
groupID: group.ID,
desc: "list group channels with invalid token",
domainID: domainID,
token: invalidToken,
groupID: group.ID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -2778,13 +2776,13 @@ func TestListGroupChannels(t *testing.T) {
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "list group channels with empty token",
token: "",
groupID: group.ID,
desc: "list group channels with empty token",
domainID: domainID,
token: "",
groupID: group.ID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcReq: groups.Page{},
svcRes: groups.Page{},
@@ -2793,13 +2791,13 @@ func TestListGroupChannels(t *testing.T) {
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "list group channels with invalid group id",
token: validToken,
groupID: wrongID,
desc: "list group channels with invalid group id",
domainID: domainID,
token: validToken,
groupID: wrongID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -2815,13 +2813,13 @@ func TestListGroupChannels(t *testing.T) {
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
},
{
desc: "list group channels with invalid page metadata",
token: validToken,
groupID: group.ID,
desc: "list group channels with invalid page metadata",
domainID: domainID,
token: validToken,
groupID: group.ID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
Metadata: sdk.Metadata{
"test": make(chan int),
},
@@ -2833,13 +2831,13 @@ func TestListGroupChannels(t *testing.T) {
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
},
{
desc: "list group channels with service response that can't be unmarshalled",
token: validToken,
groupID: group.ID,
desc: "list group channels with service response that can't be unmarshalled",
domainID: domainID,
token: validToken,
groupID: group.ID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -2868,11 +2866,11 @@ func TestListGroupChannels(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ListGroups", mock.Anything, tc.session, policies.GroupsKind, tc.groupID, tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.ListGroupChannels(tc.groupID, tc.pageMeta, tc.token)
resp, err := mgsdk.ListGroupChannels(tc.groupID, tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
+10 -10
View File
@@ -59,8 +59,8 @@ func (sdk mgSDK) CreateGroup(g Group, domainID, token string) (Group, errors.SDK
return g, nil
}
func (sdk mgSDK) Groups(pm PageMetadata, token string) (GroupsPage, errors.SDKError) {
endpoint := fmt.Sprintf("%s/%s", pm.DomainID, groupsEndpoint)
func (sdk mgSDK) Groups(pm PageMetadata, domainID, token string) (GroupsPage, errors.SDKError) {
endpoint := fmt.Sprintf("%s/%s", domainID, groupsEndpoint)
url, err := sdk.withQueryParams(sdk.usersURL, endpoint, pm)
if err != nil {
return GroupsPage{}, errors.NewSDKError(err)
@@ -69,9 +69,9 @@ func (sdk mgSDK) Groups(pm PageMetadata, token string) (GroupsPage, errors.SDKEr
return sdk.getGroups(url, token)
}
func (sdk mgSDK) Parents(id string, pm PageMetadata, token string) (GroupsPage, errors.SDKError) {
func (sdk mgSDK) Parents(id string, pm PageMetadata, domainID, token string) (GroupsPage, errors.SDKError) {
pm.Level = MaxLevel
endpoint := fmt.Sprintf("%s/%s", pm.DomainID, groupsEndpoint)
endpoint := fmt.Sprintf("%s/%s", domainID, groupsEndpoint)
url, err := sdk.withQueryParams(fmt.Sprintf("%s/%s/%s", sdk.usersURL, endpoint, id), "parents", pm)
if err != nil {
return GroupsPage{}, errors.NewSDKError(err)
@@ -80,9 +80,9 @@ func (sdk mgSDK) Parents(id string, pm PageMetadata, token string) (GroupsPage,
return sdk.getGroups(url, token)
}
func (sdk mgSDK) Children(id string, pm PageMetadata, token string) (GroupsPage, errors.SDKError) {
func (sdk mgSDK) Children(id string, pm PageMetadata, domainID, token string) (GroupsPage, errors.SDKError) {
pm.Level = MaxLevel
endpoint := fmt.Sprintf("%s/%s", pm.DomainID, groupsEndpoint)
endpoint := fmt.Sprintf("%s/%s", domainID, groupsEndpoint)
url, err := sdk.withQueryParams(fmt.Sprintf("%s/%s/%s", sdk.usersURL, endpoint, id), "children", pm)
if err != nil {
return GroupsPage{}, errors.NewSDKError(err)
@@ -197,8 +197,8 @@ func (sdk mgSDK) RemoveUserFromGroup(groupID string, req UsersRelationRequest, d
return sdkerr
}
func (sdk mgSDK) ListGroupUsers(groupID string, pm PageMetadata, token string) (UsersPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.usersURL, fmt.Sprintf("%s/%s/%s/%s", pm.DomainID, groupsEndpoint, groupID, usersEndpoint), pm)
func (sdk mgSDK) ListGroupUsers(groupID string, pm PageMetadata, domainID, token string) (UsersPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.usersURL, fmt.Sprintf("%s/%s/%s/%s", domainID, groupsEndpoint, groupID, usersEndpoint), pm)
if err != nil {
return UsersPage{}, errors.NewSDKError(err)
}
@@ -214,8 +214,8 @@ func (sdk mgSDK) ListGroupUsers(groupID string, pm PageMetadata, token string) (
return up, nil
}
func (sdk mgSDK) ListGroupChannels(groupID string, pm PageMetadata, token string) (ChannelsPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.thingsURL, fmt.Sprintf("%s/%s/%s/%s", pm.DomainID, groupsEndpoint, groupID, channelsEndpoint), pm)
func (sdk mgSDK) ListGroupChannels(groupID string, pm PageMetadata, domainID, token string) (ChannelsPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.thingsURL, fmt.Sprintf("%s/%s/%s/%s", domainID, groupsEndpoint, groupID, channelsEndpoint), pm)
if err != nil {
return ChannelsPage{}, errors.NewSDKError(err)
}
+137 -132
View File
@@ -265,7 +265,7 @@ func TestCreateGroup(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("CreateGroup", mock.Anything, tc.session, policies.NewGroupKind, tc.svcReq).Return(tc.svcRes, tc.svcErr)
@@ -305,6 +305,7 @@ func TestListGroups(t *testing.T) {
cases := []struct {
desc string
token string
domainID string
session mgauthn.Session
pageMeta sdk.PageMetadata
svcReq groups.Page
@@ -315,12 +316,12 @@ func TestListGroups(t *testing.T) {
err errors.SDKError
}{
{
desc: "list groups successfully",
token: validToken,
desc: "list groups successfully",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 100,
DomainID: domainID,
Offset: offset,
Limit: 100,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -345,12 +346,12 @@ func TestListGroups(t *testing.T) {
err: nil,
},
{
desc: "list groups with invalid token",
token: invalidToken,
desc: "list groups with invalid token",
token: invalidToken,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 100,
DomainID: domainID,
Offset: offset,
Limit: 100,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -365,26 +366,27 @@ func TestListGroups(t *testing.T) {
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "list groups with empty token",
token: "",
desc: "list groups with empty token",
domainID: domainID,
token: "",
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 100,
DomainID: domainID,
Offset: offset,
Limit: 100,
},
svcReq: groups.Page{},
svcRes: groups.Page{},
svcErr: nil,
response: sdk.GroupsPage{},
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
svcReq: groups.Page{},
svcRes: groups.Page{},
svcErr: nil,
response: sdk.GroupsPage{},
authenticateErr: svcerr.ErrAuthentication,
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "list groups with zero limit",
token: validToken,
desc: "list groups with zero limit",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 0,
DomainID: domainID,
Offset: offset,
Limit: 0,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -410,12 +412,12 @@ func TestListGroups(t *testing.T) {
err: nil,
},
{
desc: "list groups with limit greater than max",
token: validToken,
desc: "list groups with limit greater than max",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 110,
DomainID: domainID,
Offset: offset,
Limit: 110,
},
svcReq: groups.Page{},
svcRes: groups.Page{},
@@ -424,12 +426,12 @@ func TestListGroups(t *testing.T) {
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
},
{
desc: "list groups with given name",
token: validToken,
desc: "list groups with given name",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
Metadata: sdk.Metadata{
"name": "user_89",
},
@@ -461,13 +463,13 @@ func TestListGroups(t *testing.T) {
err: nil,
},
{
desc: "list groups with invalid level",
token: validToken,
desc: "list groups with invalid level",
token: validToken,
domainID: domainID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 100,
Level: 6,
DomainID: domainID,
Offset: offset,
Limit: 100,
Level: 6,
},
svcReq: groups.Page{},
svcRes: groups.Page{},
@@ -476,15 +478,15 @@ func TestListGroups(t *testing.T) {
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidLevel), http.StatusBadRequest),
},
{
desc: "list groups with invalid page metadata",
token: validToken,
desc: "list groups with invalid page metadata",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
Metadata: sdk.Metadata{
"key": make(chan int),
},
DomainID: domainID,
},
svcReq: groups.Page{},
svcRes: groups.Page{},
@@ -493,12 +495,12 @@ func TestListGroups(t *testing.T) {
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
},
{
desc: "list groups with service response that cannot be unmarshalled",
token: validToken,
desc: "list groups with service response that cannot be unmarshalled",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -528,11 +530,11 @@ func TestListGroups(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ListGroups", mock.Anything, tc.session, policies.UsersKind, "", tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.Groups(tc.pageMeta, tc.token)
resp, err := mgsdk.Groups(tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
@@ -572,6 +574,7 @@ func TestListParentGroups(t *testing.T) {
cases := []struct {
desc string
token string
domainID string
session mgauthn.Session
pageMeta sdk.PageMetadata
parentID string
@@ -584,12 +587,12 @@ func TestListParentGroups(t *testing.T) {
}{
{
desc: "list parent groups successfully",
domainID: domainID,
token: validToken,
parentID: parentID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -617,12 +620,12 @@ func TestListParentGroups(t *testing.T) {
},
{
desc: "list parent groups with invalid token",
domainID: domainID,
token: invalidToken,
parentID: parentID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -641,12 +644,12 @@ func TestListParentGroups(t *testing.T) {
},
{
desc: "list parent groups with empty token",
domainID: domainID,
token: "",
parentID: parentID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
},
svcReq: groups.Page{},
svcRes: groups.Page{},
@@ -656,12 +659,12 @@ func TestListParentGroups(t *testing.T) {
},
{
desc: "list parent groups with zero limit",
domainID: domainID,
token: validToken,
parentID: parentID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 0,
DomainID: domainID,
Offset: offset,
Limit: 0,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -689,12 +692,12 @@ func TestListParentGroups(t *testing.T) {
},
{
desc: "list parent groups with limit greater than max",
domainID: domainID,
token: validToken,
parentID: parentID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 110,
DomainID: domainID,
Offset: offset,
Limit: 110,
},
svcReq: groups.Page{},
svcRes: groups.Page{},
@@ -704,12 +707,12 @@ func TestListParentGroups(t *testing.T) {
},
{
desc: "list parent groups with given metadata",
domainID: domainID,
token: validToken,
parentID: parentID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
Metadata: sdk.Metadata{
"name": "user_89",
},
@@ -743,12 +746,12 @@ func TestListParentGroups(t *testing.T) {
},
{
desc: "list parent groups with invalid page metadata",
domainID: domainID,
token: validToken,
parentID: parentID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
Metadata: sdk.Metadata{
"key": make(chan int),
},
@@ -761,6 +764,7 @@ func TestListParentGroups(t *testing.T) {
},
{
desc: "list parent groups with service response that cannot be unmarshalled",
domainID: domainID,
token: validToken,
parentID: parentID,
pageMeta: sdk.PageMetadata{
@@ -799,11 +803,11 @@ func TestListParentGroups(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ListGroups", mock.Anything, tc.session, policies.UsersKind, "", tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.Parents(tc.parentID, tc.pageMeta, tc.token)
resp, err := mgsdk.Parents(tc.parentID, tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
@@ -844,6 +848,7 @@ func TestListChildrenGroups(t *testing.T) {
cases := []struct {
desc string
token string
domainID string
session mgauthn.Session
childID string
pageMeta sdk.PageMetadata
@@ -855,13 +860,13 @@ func TestListChildrenGroups(t *testing.T) {
err errors.SDKError
}{
{
desc: "list children groups successfully",
token: validToken,
childID: childID,
desc: "list children groups successfully",
domainID: domainID,
token: validToken,
childID: childID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -888,13 +893,13 @@ func TestListChildrenGroups(t *testing.T) {
err: nil,
},
{
desc: "list children groups with invalid token",
token: invalidToken,
childID: childID,
desc: "list children groups with invalid token",
domainID: domainID,
token: invalidToken,
childID: childID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -912,13 +917,13 @@ func TestListChildrenGroups(t *testing.T) {
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "list children groups with empty token",
token: "",
childID: childID,
desc: "list children groups with empty token",
domainID: domainID,
token: "",
childID: childID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
},
svcReq: groups.Page{},
svcRes: groups.Page{},
@@ -927,13 +932,13 @@ func TestListChildrenGroups(t *testing.T) {
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "list children groups with zero limit",
token: validToken,
childID: childID,
desc: "list children groups with zero limit",
domainID: domainID,
token: validToken,
childID: childID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 0,
DomainID: domainID,
Offset: offset,
Limit: 0,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -960,12 +965,12 @@ func TestListChildrenGroups(t *testing.T) {
err: nil,
},
{
desc: "list children groups with limit greater than max",
token: validToken,
desc: "list children groups with limit greater than max",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: 110,
DomainID: domainID,
Offset: offset,
Limit: 110,
},
svcReq: groups.Page{},
svcRes: groups.Page{},
@@ -974,13 +979,13 @@ func TestListChildrenGroups(t *testing.T) {
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
},
{
desc: "list children groups with given metadata",
token: validToken,
childID: childID,
desc: "list children groups with given metadata",
domainID: domainID,
token: validToken,
childID: childID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
Metadata: sdk.Metadata{
"name": "user_89",
},
@@ -1013,13 +1018,13 @@ func TestListChildrenGroups(t *testing.T) {
err: nil,
},
{
desc: "list children groups with invalid page metadata",
token: validToken,
childID: childID,
desc: "list children groups with invalid page metadata",
domainID: domainID,
token: validToken,
childID: childID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
Metadata: sdk.Metadata{
"key": make(chan int),
},
@@ -1031,13 +1036,13 @@ func TestListChildrenGroups(t *testing.T) {
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
},
{
desc: "list children groups with service response that cannot be unmarshalled",
token: validToken,
childID: childID,
desc: "list children groups with service response that cannot be unmarshalled",
domainID: domainID,
token: validToken,
childID: childID,
pageMeta: sdk.PageMetadata{
Offset: offset,
Limit: limit,
DomainID: domainID,
Offset: offset,
Limit: limit,
},
svcReq: groups.Page{
PageMeta: groups.PageMeta{
@@ -1070,11 +1075,11 @@ func TestListChildrenGroups(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ListGroups", mock.Anything, tc.session, policies.UsersKind, "", tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.Children(tc.childID, tc.pageMeta, tc.token)
resp, err := mgsdk.Children(tc.childID, tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
@@ -1179,7 +1184,7 @@ func TestViewGroup(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ViewGroup", mock.Anything, tc.session, tc.groupID).Return(tc.svcRes, tc.svcErr)
@@ -1273,7 +1278,7 @@ func TestViewGroupPermissions(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("ViewGroupPerms", mock.Anything, tc.session, tc.groupID).Return(tc.svcRes, tc.svcErr)
@@ -1462,7 +1467,7 @@ func TestUpdateGroup(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("UpdateGroup", mock.Anything, tc.session, tc.svcReq).Return(tc.svcRes, tc.svcErr)
@@ -1573,7 +1578,7 @@ func TestEnableGroup(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("EnableGroup", mock.Anything, tc.session, tc.groupID).Return(tc.svcRes, tc.svcErr)
@@ -1684,7 +1689,7 @@ func TestDisableGroup(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("DisableGroup", mock.Anything, tc.session, tc.groupID).Return(tc.svcRes, tc.svcErr)
@@ -1764,7 +1769,7 @@ func TestDeleteGroup(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("DeleteGroup", mock.Anything, tc.session, tc.groupID).Return(tc.svcErr)
@@ -1888,7 +1893,7 @@ func TestAddUserToGroup(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Assign", mock.Anything, tc.session, tc.groupID, tc.addUserReq.Relation, policies.UsersKind, tc.addUserReq.UserIDs).Return(tc.svcErr)
@@ -2000,7 +2005,7 @@ func TestRemoveUserFromGroup(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := gsvc.On("Unassign", mock.Anything, tc.session, tc.groupID, tc.removeUserReq.Relation, policies.UsersKind, tc.removeUserReq.UserIDs).Return(tc.svcErr)
+2 -2
View File
@@ -66,8 +66,8 @@ func (sdk mgSDK) Invitation(userID, domainID, token string) (invitation Invitati
return invitation, nil
}
func (sdk mgSDK) Invitations(pm PageMetadata, token string) (invitations InvitationPage, err error) {
endpoint := fmt.Sprintf("%s/%s", pm.DomainID, invitationsEndpoint)
func (sdk mgSDK) Invitations(pm PageMetadata, domainID, token string) (invitations InvitationPage, err error) {
endpoint := fmt.Sprintf("%s/%s", domainID, invitationsEndpoint)
url, err := sdk.withQueryParams(sdk.invitationsURL, endpoint, pm)
if err != nil {
+64 -63
View File
@@ -131,14 +131,14 @@ func TestSendInvitation(t *testing.T) {
Relation: invitation.Relation,
Resend: invitation.Resend,
},
svcErr: svcerr.ErrCreateEntity,
err: errors.NewSDKErrorWithStatus(svcerr.ErrCreateEntity, http.StatusUnprocessableEntity),
authenticateErr: svcerr.ErrCreateEntity,
err: errors.NewSDKErrorWithStatus(svcerr.ErrCreateEntity, http.StatusUnprocessableEntity),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: invitation.DomainID + "_" + validID, UserID: validID, DomainID: invitation.DomainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("SendInvitation", mock.Anything, tc.session, tc.svcReq).Return(tc.svcErr)
@@ -206,20 +206,20 @@ func TestViewInvitation(t *testing.T) {
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "view invitation with invalid domainID",
token: validToken,
userID: invitation.UserID,
domainID: wrongID,
svcRes: invitations.Invitation{},
svcErr: svcerr.ErrNotFound,
response: sdk.Invitation{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
desc: "view invitation with invalid domainID",
token: validToken,
userID: invitation.UserID,
domainID: wrongID,
svcRes: invitations.Invitation{},
authenticateErr: svcerr.ErrNotFound,
response: sdk.Invitation{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: invitation.DomainID + "_" + validID, UserID: validID, DomainID: invitation.DomainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("ViewInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcRes, tc.svcErr)
@@ -248,6 +248,7 @@ func TestListInvitation(t *testing.T) {
cases := []struct {
desc string
token string
domainID string
session mgauthn.Session
pageMeta sdk.PageMetadata
svcReq invitations.Page
@@ -258,17 +259,17 @@ func TestListInvitation(t *testing.T) {
err error
}{
{
desc: "list invitations successfully",
token: validToken,
desc: "list invitations successfully",
domainID: invitation.DomainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
Offset: 0,
Limit: 10,
},
svcReq: invitations.Page{
Offset: 0,
Limit: 10,
DomainID: domainID,
DomainID: invitation.DomainID,
},
svcRes: invitations.InvitationPage{
Total: 1,
@@ -282,28 +283,28 @@ func TestListInvitation(t *testing.T) {
err: nil,
},
{
desc: "list invitations with invalid token",
token: invalidToken,
desc: "list invitations with invalid token",
domainID: invitation.DomainID,
token: invalidToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 10,
DomainID: domainID,
},
svcReq: invitations.Page{
Offset: 0,
Limit: 10,
},
svcReq: invitations.Page{
Offset: 0,
Limit: 10,
DomainID: invitation.DomainID,
},
svcRes: invitations.InvitationPage{},
authenticateErr: svcerr.ErrAuthentication,
response: sdk.InvitationPage{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "list invitations with empty token",
token: "",
pageMeta: sdk.PageMetadata{
DomainID: domainID,
},
desc: "list invitations with empty token",
domainID: invitation.DomainID,
token: "",
pageMeta: sdk.PageMetadata{},
svcRes: invitations.InvitationPage{},
svcErr: nil,
response: sdk.InvitationPage{},
@@ -316,15 +317,15 @@ func TestListInvitation(t *testing.T) {
svcRes: invitations.InvitationPage{},
svcErr: nil,
response: sdk.InvitationPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingDomainID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingDomainID, http.StatusBadRequest),
},
{
desc: "list invitations with limit greater than max limit",
token: validToken,
desc: "list invitations with limit greater than max limit",
token: validToken,
domainID: invitation.DomainID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 101,
DomainID: domainID,
Offset: 0,
Limit: 101,
},
svcReq: invitations.Page{},
svcRes: invitations.InvitationPage{},
@@ -335,12 +336,12 @@ func TestListInvitation(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: invitation.DomainID + "_" + validID, UserID: validID, DomainID: invitation.DomainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("ListInvitations", mock.Anything, tc.session, tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.Invitations(tc.pageMeta, tc.token)
resp, err := mgsdk.Invitations(tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
@@ -393,17 +394,17 @@ func TestAcceptInvitation(t *testing.T) {
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "accept invitation with invalid domainID",
token: validToken,
domainID: wrongID,
svcErr: svcerr.ErrNotFound,
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
desc: "accept invitation with invalid domainID",
token: validToken,
domainID: wrongID,
authenticateErr: svcerr.ErrNotFound,
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: invitation.DomainID + "_" + validID, UserID: validID, DomainID: invitation.DomainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("AcceptInvitation", mock.Anything, tc.session, tc.domainID).Return(tc.svcErr)
@@ -459,17 +460,17 @@ func TestRejectInvitation(t *testing.T) {
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "reject invitation with invalid domainID",
token: validToken,
domainID: wrongID,
svcErr: svcerr.ErrNotFound,
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
desc: "reject invitation with invalid domainID",
token: validToken,
domainID: wrongID,
authenticateErr: svcerr.ErrNotFound,
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: invitation.DomainID + "_" + validID, UserID: validID, DomainID: invitation.DomainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("RejectInvitation", mock.Anything, tc.session, tc.domainID).Return(tc.svcErr)
@@ -529,18 +530,18 @@ func TestDeleteInvitation(t *testing.T) {
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "delete invitation with invalid domainID",
token: validToken,
userID: invitation.UserID,
domainID: wrongID,
svcErr: svcerr.ErrNotFound,
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
desc: "delete invitation with invalid domainID",
token: validToken,
userID: invitation.UserID,
domainID: wrongID,
authenticateErr: svcerr.ErrNotFound,
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: invitation.DomainID + "_" + validID, UserID: validID, DomainID: invitation.DomainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcErr)
+28 -34
View File
@@ -293,13 +293,9 @@ type SDK interface {
// RefreshToken receives credentials and returns user token.
//
// example:
// lt := sdk.Login{
// DomainID: "domain_id",
// }
// example:
// token, _ := sdk.RefreshToken(lt,"refresh_token")
// token, _ := sdk.RefreshToken("refresh_token")
// fmt.Println(token)
RefreshToken(lt Login, token string) (Token, errors.SDKError)
RefreshToken(token string) (Token, errors.SDKError)
// ListUserChannels list all channels belongs a particular user id.
//
@@ -391,9 +387,9 @@ type SDK interface {
// Limit: 10,
// Name: "My Thing",
// }
// things, _ := sdk.Things(pm, "token")
// things, _ := sdk.Things(pm, "domainID", "token")
// fmt.Println(things)
Things(pm PageMetadata, token string) (ThingsPage, errors.SDKError)
Things(pm PageMetadata, domainID, token string) (ThingsPage, errors.SDKError)
// ThingsByChannel returns page of things that are connected to specified channel.
//
@@ -403,9 +399,9 @@ type SDK interface {
// Limit: 10,
// Name: "My Thing",
// }
// things, _ := sdk.ThingsByChannel("channelID", pm, "token")
// things, _ := sdk.ThingsByChannel("channelID", pm, "domainID", "token")
// fmt.Println(things)
ThingsByChannel(chanID string, pm PageMetadata, token string) (ThingsPage, errors.SDKError)
ThingsByChannel(chanID string, pm PageMetadata, domainID, token string) (ThingsPage, errors.SDKError)
// Thing returns thing object by id.
//
@@ -497,9 +493,9 @@ type SDK interface {
// Limit: 10,
// Permission: "edit", // available Options: "administrator", "administrator", "delete", edit", "view", "share", "owner", "owner", "admin", "editor", "contributor", "editor", "viewer", "guest", "create"
// }
// users, _ := sdk.ListThingUsers("thing_id", pm, "token")
// users, _ := sdk.ListThingUsers("thing_id", pm, "domainID", "token")
// fmt.Println(users)
ListThingUsers(thingID string, pm PageMetadata, token string) (UsersPage, errors.SDKError)
ListThingUsers(thingID string, pm PageMetadata, domainID, token string) (UsersPage, errors.SDKError)
// DeleteThing deletes a thing with the given id.
//
@@ -529,9 +525,9 @@ type SDK interface {
// Limit: 10,
// Name: "My Group",
// }
// groups, _ := sdk.Groups(pm, "token")
// groups, _ := sdk.Groups(pm, "domainID", "token")
// fmt.Println(groups)
Groups(pm PageMetadata, token string) (GroupsPage, errors.SDKError)
Groups(pm PageMetadata, domainID, token string) (GroupsPage, errors.SDKError)
// Parents returns page of users groups.
//
@@ -541,9 +537,9 @@ type SDK interface {
// Limit: 10,
// Name: "My Group",
// }
// groups, _ := sdk.Parents("groupID", pm, "token")
// groups, _ := sdk.Parents("groupID", pm, "domainID", "token")
// fmt.Println(groups)
Parents(id string, pm PageMetadata, token string) (GroupsPage, errors.SDKError)
Parents(id string, pm PageMetadata, domainID, token string) (GroupsPage, errors.SDKError)
// Children returns page of users groups.
//
@@ -553,9 +549,9 @@ type SDK interface {
// Limit: 10,
// Name: "My Group",
// }
// groups, _ := sdk.Children("groupID", pm, "token")
// groups, _ := sdk.Children("groupID", pm, "domainID", "token")
// fmt.Println(groups)
Children(id string, pm PageMetadata, token string) (GroupsPage, errors.SDKError)
Children(id string, pm PageMetadata, domainID, token string) (GroupsPage, errors.SDKError)
// Group returns users group object by id.
//
@@ -629,9 +625,9 @@ type SDK interface {
// Limit: 10,
// Permission: "edit", // available Options: "administrator", "administrator", "delete", edit", "view", "share", "owner", "owner", "admin", "editor", "contributor", "editor", "viewer", "guest", "create"
// }
// groups, _ := sdk.ListGroupUsers("groupID", pm, "token")
// groups, _ := sdk.ListGroupUsers("groupID", pm, "domainID", "token")
// fmt.Println(groups)
ListGroupUsers(groupID string, pm PageMetadata, token string) (UsersPage, errors.SDKError)
ListGroupUsers(groupID string, pm PageMetadata, domainID, token string) (UsersPage, errors.SDKError)
// ListGroupChannels list all channels in the group id .
//
@@ -639,12 +635,11 @@ type SDK interface {
// pm := sdk.PageMetadata{
// Offset: 0,
// Limit: 10,
// DomainID: "domain"
// Permission: "edit", // available Options: "administrator", "administrator", "delete", edit", "view", "share", "owner", "owner", "admin", "editor", "contributor", "editor", "viewer", "guest", "create"
// }
// groups, _ := sdk.ListGroupChannels("groupID", pm, "token")
// groups, _ := sdk.ListGroupChannels("groupID", pm, "domainID", "token")
// fmt.Println(groups)
ListGroupChannels(groupID string, pm PageMetadata, token string) (ChannelsPage, errors.SDKError)
ListGroupChannels(groupID string, pm PageMetadata, domainID, token string) (ChannelsPage, errors.SDKError)
// DeleteGroup delete given group id.
//
@@ -673,11 +668,10 @@ type SDK interface {
// Offset: 0,
// Limit: 10,
// Name: "My Channel",
// Domain: "domainID"
// }
// channels, _ := sdk.Channels(pm, "token")
// channels, _ := sdk.Channels(pm, "domainID", "token")
// fmt.Println(channels)
Channels(pm PageMetadata, token string) (ChannelsPage, errors.SDKError)
Channels(pm PageMetadata, domainID, token string) (ChannelsPage, errors.SDKError)
// ChannelsByThing returns page of channels that are connected to specified thing.
//
@@ -763,9 +757,9 @@ type SDK interface {
// Limit: 10,
// Permission: "edit", // available Options: "administrator", "administrator", "delete", edit", "view", "share", "owner", "owner", "admin", "editor", "contributor", "editor", "viewer", "guest", "create"
// }
// users, _ := sdk.ListChannelUsers("channel_id", pm, "token")
// users, _ := sdk.ListChannelUsers("channel_id", pm, "domainID", "token")
// fmt.Println(users)
ListChannelUsers(channelID string, pm PageMetadata, token string) (UsersPage, errors.SDKError)
ListChannelUsers(channelID string, pm PageMetadata, domainID, token string) (UsersPage, errors.SDKError)
// AddUserGroupToChannel add user group to a channel.
//
@@ -795,9 +789,9 @@ type SDK interface {
// Limit: 10,
// Permission: "view",
// }
// groups, _ := sdk.ListChannelUserGroups("channel_id_1", pm, "token")
// groups, _ := sdk.ListChannelUserGroups("channel_id_1", pm, "domainID", "token")
// fmt.Println(groups)
ListChannelUserGroups(channelID string, pm PageMetadata, token string) (GroupsPage, errors.SDKError)
ListChannelUserGroups(channelID string, pm PageMetadata, domainID, token string) (GroupsPage, errors.SDKError)
// DeleteChannel delete given group id.
//
@@ -956,9 +950,9 @@ type SDK interface {
// Offset: 0,
// Limit: 10,
// }
// bootstraps, _ := sdk.Bootstraps(pm, "token")
// bootstraps, _ := sdk.Bootstraps(pm, "domainID", "token")
// fmt.Println(bootstraps)
Bootstraps(pm PageMetadata, token string) (BootstrapPage, errors.SDKError)
Bootstraps(pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError)
// Whitelist updates Thing state Config with given ID belonging to the user identified by the given token.
//
@@ -1159,9 +1153,9 @@ type SDK interface {
// Invitations returns a list of invitations.
//
// For example:
// invitations, _ := sdk.Invitations(PageMetadata{Offset: 0, Limit: 10, Domain: "domainID"}, "token")
// invitations, _ := sdk.Invitations(PageMetadata{Offset: 0, Limit: 10}, "domainID", "token")
// fmt.Println(invitations)
Invitations(pm PageMetadata, token string) (invitations InvitationPage, err error)
Invitations(pm PageMetadata, domainID, token string) (invitations InvitationPage, err error)
// AcceptInvitation accepts an invitation by adding the user to the domain that they were invited to.
//
+6 -6
View File
@@ -79,8 +79,8 @@ func (sdk mgSDK) CreateThings(things []Thing, domainID, token string) ([]Thing,
return ctr.Things, nil
}
func (sdk mgSDK) Things(pm PageMetadata, token string) (ThingsPage, errors.SDKError) {
endpoint := fmt.Sprintf("%s/%s", pm.DomainID, thingsEndpoint)
func (sdk mgSDK) Things(pm PageMetadata, domainID, token string) (ThingsPage, errors.SDKError) {
endpoint := fmt.Sprintf("%s/%s", domainID, thingsEndpoint)
url, err := sdk.withQueryParams(sdk.thingsURL, endpoint, pm)
if err != nil {
return ThingsPage{}, errors.NewSDKError(err)
@@ -99,8 +99,8 @@ func (sdk mgSDK) Things(pm PageMetadata, token string) (ThingsPage, errors.SDKEr
return cp, nil
}
func (sdk mgSDK) ThingsByChannel(chanID string, pm PageMetadata, token string) (ThingsPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.thingsURL, fmt.Sprintf("%s/channels/%s/%s", pm.DomainID, chanID, thingsEndpoint), pm)
func (sdk mgSDK) ThingsByChannel(chanID string, pm PageMetadata, domainID, token string) (ThingsPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.thingsURL, fmt.Sprintf("%s/channels/%s/%s", domainID, chanID, thingsEndpoint), pm)
if err != nil {
return ThingsPage{}, errors.NewSDKError(err)
}
@@ -269,8 +269,8 @@ func (sdk mgSDK) UnshareThing(thingID string, req UsersRelationRequest, domainID
return sdkerr
}
func (sdk mgSDK) ListThingUsers(thingID string, pm PageMetadata, token string) (UsersPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.usersURL, fmt.Sprintf("%s/%s/%s/%s", pm.DomainID, thingsEndpoint, thingID, usersEndpoint), pm)
func (sdk mgSDK) ListThingUsers(thingID string, pm PageMetadata, domainID, token string) (UsersPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.usersURL, fmt.Sprintf("%s/%s/%s/%s", domainID, thingsEndpoint, thingID, usersEndpoint), pm)
if err != nil {
return UsersPage{}, errors.NewSDKError(err)
}
+80 -78
View File
@@ -190,7 +190,7 @@ func TestCreateThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("CreateThings", mock.Anything, tc.session, tc.svcReq).Return(tc.svcRes, tc.svcErr)
@@ -301,7 +301,7 @@ func TestCreateThings(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("CreateThings", mock.Anything, tc.session, tc.svcReq[0], tc.svcReq[1], tc.svcReq[2]).Return(tc.svcRes, tc.svcErr)
@@ -340,6 +340,7 @@ func TestListThings(t *testing.T) {
cases := []struct {
desc string
token string
domainID string
session mgauthn.Session
pageMeta sdk.PageMetadata
svcReq mgclients.Page
@@ -350,12 +351,12 @@ func TestListThings(t *testing.T) {
err errors.SDKError
}{
{
desc: "list all things successfully",
token: validToken,
desc: "list all things successfully",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
DomainID: domainID,
Offset: 0,
Limit: 100,
},
svcReq: mgclients.Page{
Offset: 0,
@@ -381,12 +382,12 @@ func TestListThings(t *testing.T) {
},
},
{
desc: "list all things with an invalid token",
token: invalidToken,
desc: "list all things with an invalid token",
domainID: domainID,
token: invalidToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
DomainID: domainID,
Offset: 0,
Limit: 100,
},
svcReq: mgclients.Page{
Offset: 0,
@@ -400,12 +401,12 @@ func TestListThings(t *testing.T) {
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "list all things with limit greater than max",
token: validToken,
desc: "list all things with limit greater than max",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 1000,
DomainID: domainID,
Offset: 0,
Limit: 1000,
},
svcReq: mgclients.Page{},
svcRes: mgclients.ClientsPage{},
@@ -414,13 +415,13 @@ func TestListThings(t *testing.T) {
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
},
{
desc: "list all things with name size greater than max",
token: validToken,
desc: "list all things with name size greater than max",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
Name: strings.Repeat("a", 1025),
DomainID: domainID,
Offset: 0,
Limit: 100,
Name: strings.Repeat("a", 1025),
},
svcReq: mgclients.Page{},
svcRes: mgclients.ClientsPage{},
@@ -429,13 +430,13 @@ func TestListThings(t *testing.T) {
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
},
{
desc: "list all things with status",
token: validToken,
desc: "list all things with status",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
Status: mgclients.DisabledStatus.String(),
DomainID: domainID,
Offset: 0,
Limit: 100,
Status: mgclients.DisabledStatus.String(),
},
svcReq: mgclients.Page{
Offset: 0,
@@ -463,13 +464,13 @@ func TestListThings(t *testing.T) {
err: nil,
},
{
desc: "list all things with tags",
token: validToken,
desc: "list all things with tags",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
Tag: "tag1",
DomainID: domainID,
Offset: 0,
Limit: 100,
Tag: "tag1",
},
svcReq: mgclients.Page{
Offset: 0,
@@ -497,15 +498,15 @@ func TestListThings(t *testing.T) {
err: nil,
},
{
desc: "list all things with invalid metadata",
token: validToken,
desc: "list all things with invalid metadata",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
Metadata: map[string]interface{}{
"test": make(chan int),
},
DomainID: domainID,
},
svcReq: mgclients.Page{},
svcRes: mgclients.ClientsPage{},
@@ -514,12 +515,12 @@ func TestListThings(t *testing.T) {
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
},
{
desc: "list all things with response that can't be unmarshalled",
token: validToken,
desc: "list all things with response that can't be unmarshalled",
domainID: domainID,
token: validToken,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
DomainID: domainID,
Offset: 0,
Limit: 100,
},
svcReq: mgclients.Page{
Offset: 0,
@@ -550,11 +551,11 @@ func TestListThings(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("ListClients", mock.Anything, tc.session, mock.Anything, tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.Things(tc.pageMeta, tc.token)
resp, err := mgsdk.Things(tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
@@ -588,6 +589,7 @@ func TestListThingsByChannel(t *testing.T) {
cases := []struct {
desc string
token string
domainID string
session mgauthn.Session
channelID string
pageMeta sdk.PageMetadata
@@ -600,12 +602,12 @@ func TestListThingsByChannel(t *testing.T) {
}{
{
desc: "list things successfully",
domainID: domainID,
token: validToken,
channelID: validID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
DomainID: domainID,
Offset: 0,
Limit: 100,
},
svcReq: mgclients.Page{
Offset: 0,
@@ -632,12 +634,12 @@ func TestListThingsByChannel(t *testing.T) {
},
{
desc: "list things with an invalid token",
domainID: domainID,
token: invalidToken,
channelID: validID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
DomainID: domainID,
Offset: 0,
Limit: 100,
},
svcReq: mgclients.Page{
Offset: 0,
@@ -652,12 +654,12 @@ func TestListThingsByChannel(t *testing.T) {
},
{
desc: "list things with empty token",
domainID: domainID,
token: "",
channelID: validID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
DomainID: domainID,
Offset: 0,
Limit: 100,
},
svcReq: mgclients.Page{},
svcRes: mgclients.MembersPage{},
@@ -667,13 +669,13 @@ func TestListThingsByChannel(t *testing.T) {
},
{
desc: "list things with status",
domainID: domainID,
token: validToken,
channelID: validID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
Status: mgclients.DisabledStatus.String(),
DomainID: domainID,
Offset: 0,
Limit: 100,
Status: mgclients.DisabledStatus.String(),
},
svcReq: mgclients.Page{
Offset: 0,
@@ -702,12 +704,12 @@ func TestListThingsByChannel(t *testing.T) {
},
{
desc: "list things with empty channel id",
domainID: domainID,
token: validToken,
channelID: "",
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
DomainID: domainID,
Offset: 0,
Limit: 100,
},
svcReq: mgclients.Page{},
svcRes: mgclients.MembersPage{},
@@ -717,6 +719,7 @@ func TestListThingsByChannel(t *testing.T) {
},
{
desc: "list things with invalid metadata",
domainID: domainID,
token: validToken,
channelID: validID,
pageMeta: sdk.PageMetadata{
@@ -725,7 +728,6 @@ func TestListThingsByChannel(t *testing.T) {
Metadata: map[string]interface{}{
"test": make(chan int),
},
DomainID: domainID,
},
svcReq: mgclients.Page{},
svcRes: mgclients.MembersPage{},
@@ -735,12 +737,12 @@ func TestListThingsByChannel(t *testing.T) {
},
{
desc: "list things with response that can't be unmarshalled",
domainID: domainID,
token: validToken,
channelID: validID,
pageMeta: sdk.PageMetadata{
Offset: 0,
Limit: 100,
DomainID: domainID,
Offset: 0,
Limit: 100,
},
svcReq: mgclients.Page{
Offset: 0,
@@ -771,11 +773,11 @@ func TestListThingsByChannel(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("ListClientsByGroup", mock.Anything, tc.session, tc.channelID, tc.svcReq).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.ThingsByChannel(tc.channelID, tc.pageMeta, tc.token)
resp, err := mgsdk.ThingsByChannel(tc.channelID, tc.pageMeta, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
@@ -881,7 +883,7 @@ func TestViewThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("ViewClient", mock.Anything, tc.session, tc.thingID).Return(tc.svcRes, tc.svcErr)
@@ -976,7 +978,7 @@ func TestViewThingPermissions(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("ViewClientPerms", mock.Anything, tc.session, tc.thingID).Return(tc.svcRes, tc.svcErr)
@@ -1134,7 +1136,7 @@ func TestUpdateThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("UpdateClient", mock.Anything, tc.session, tc.svcReq).Return(tc.svcRes, tc.svcErr)
@@ -1286,7 +1288,7 @@ func TestUpdateThingTags(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("UpdateClientTags", mock.Anything, tc.session, tc.svcReq).Return(tc.svcRes, tc.svcErr)
@@ -1418,7 +1420,7 @@ func TestUpdateThingSecret(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("UpdateClientSecret", mock.Anything, tc.session, tc.thingID, tc.newSecret).Return(tc.svcRes, tc.svcErr)
@@ -1521,7 +1523,7 @@ func TestEnableThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("EnableClient", mock.Anything, tc.session, tc.thingID).Return(tc.svcRes, tc.svcErr)
@@ -1624,7 +1626,7 @@ func TestDisableThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("DisableClient", mock.Anything, tc.session, tc.thingID).Return(tc.svcRes, tc.svcErr)
@@ -1739,7 +1741,7 @@ func TestShareThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("Share", mock.Anything, tc.session, tc.thingID, tc.shareReq.Relation, tc.shareReq.UserIDs[0]).Return(tc.svcErr)
@@ -1840,7 +1842,7 @@ func TestUnshareThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("Unshare", mock.Anything, tc.session, tc.thingID, tc.shareReq.Relation, tc.shareReq.UserIDs[0]).Return(tc.svcErr)
@@ -1921,7 +1923,7 @@ func TestDeleteThing(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("DeleteClient", mock.Anything, tc.session, tc.thingID).Return(tc.svcErr)
@@ -2177,7 +2179,7 @@ func TestListUserThings(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr)
svcCall := tsvc.On("ListClients", mock.Anything, tc.session, tc.userID, tc.svcReq).Return(tc.svcRes, tc.svcErr)
+2 -7
View File
@@ -22,7 +22,6 @@ type Token struct {
type Login struct {
Identity string `json:"identity"`
Secret string `json:"secret"`
DomainID string `json:"domain_id,omitempty"`
}
func (sdk mgSDK) CreateToken(lt Login) (Token, errors.SDKError) {
@@ -45,14 +44,10 @@ func (sdk mgSDK) CreateToken(lt Login) (Token, errors.SDKError) {
return token, nil
}
func (sdk mgSDK) RefreshToken(lt Login, token string) (Token, errors.SDKError) {
data, err := json.Marshal(lt)
if err != nil {
return Token{}, errors.NewSDKError(err)
}
func (sdk mgSDK) RefreshToken(token string) (Token, errors.SDKError) {
url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, usersEndpoint, refreshTokenEndpoint)
_, body, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusCreated)
_, body, sdkerr := sdk.processRequest(http.MethodPost, url, token, nil, nil, http.StatusCreated)
if sdkerr != nil {
return Token{}, sdkerr
}
+9 -24
View File
@@ -43,7 +43,6 @@ func TestIssueToken(t *testing.T) {
login: sdk.Login{
Identity: client.Credentials.Identity,
Secret: client.Credentials.Secret,
DomainID: validID,
},
svcRes: &magistrala.Token{
AccessToken: token.AccessToken,
@@ -59,7 +58,6 @@ func TestIssueToken(t *testing.T) {
login: sdk.Login{
Identity: invalidIdentity,
Secret: client.Credentials.Secret,
DomainID: validID,
},
svcRes: &magistrala.Token{},
svcErr: svcerr.ErrAuthentication,
@@ -71,7 +69,6 @@ func TestIssueToken(t *testing.T) {
login: sdk.Login{
Identity: client.Credentials.Identity,
Secret: "invalid",
DomainID: validID,
},
svcRes: &magistrala.Token{},
svcErr: svcerr.ErrLogin,
@@ -83,7 +80,6 @@ func TestIssueToken(t *testing.T) {
login: sdk.Login{
Identity: "",
Secret: client.Credentials.Secret,
DomainID: validID,
},
svcRes: &magistrala.Token{},
svcErr: nil,
@@ -95,7 +91,6 @@ func TestIssueToken(t *testing.T) {
login: sdk.Login{
Identity: client.Credentials.Identity,
Secret: "",
DomainID: validID,
},
svcRes: &magistrala.Token{},
svcErr: nil,
@@ -105,12 +100,12 @@ func TestIssueToken(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("IssueToken", mock.Anything, tc.login.Identity, tc.login.Secret, tc.login.DomainID).Return(tc.svcRes, tc.svcErr)
svcCall := svc.On("IssueToken", mock.Anything, tc.login.Identity, tc.login.Secret).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.CreateToken(tc.login)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "IssueToken", mock.Anything, tc.login.Identity, tc.login.Secret, tc.login.DomainID)
ok := svcCall.Parent.AssertCalled(t, "IssueToken", mock.Anything, tc.login.Identity, tc.login.Secret)
assert.True(t, ok)
}
svcCall.Unset()
@@ -132,7 +127,6 @@ func TestRefreshToken(t *testing.T) {
cases := []struct {
desc string
token string
login sdk.Login
svcRes *magistrala.Token
svcErr error
identifyErr error
@@ -142,9 +136,6 @@ func TestRefreshToken(t *testing.T) {
{
desc: "refresh token successfully",
token: token.RefreshToken,
login: sdk.Login{
DomainID: validID,
},
svcRes: &magistrala.Token{
AccessToken: token.AccessToken,
RefreshToken: &token.RefreshToken,
@@ -154,22 +145,16 @@ func TestRefreshToken(t *testing.T) {
err: nil,
},
{
desc: "refresh token with invalid token",
token: invalidToken,
login: sdk.Login{
DomainID: validID,
},
desc: "refresh token with invalid token",
token: invalidToken,
svcRes: nil,
identifyErr: svcerr.ErrAuthentication,
response: sdk.Token{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "refresh token with empty token",
token: "",
login: sdk.Login{
DomainID: validID,
},
desc: "refresh token with empty token",
token: "",
response: sdk.Token{},
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
@@ -177,12 +162,12 @@ func TestRefreshToken(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tc.identifyErr)
svcCall := svc.On("RefreshToken", mock.Anything, mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tc.token, tc.login.DomainID).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.RefreshToken(tc.login, tc.token)
svcCall := svc.On("RefreshToken", mock.Anything, mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tc.token).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.RefreshToken(tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "RefreshToken", mock.Anything, mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tc.token, tc.login.DomainID)
ok := svcCall.Parent.AssertCalled(t, "RefreshToken", mock.Anything, mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, tc.token)
assert.True(t, ok)
}
svcCall.Unset()
+2 -2
View File
@@ -2131,7 +2131,7 @@ func TestListMembers(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("ListMembers", mock.Anything, tc.session, "groups", tc.groupID, tc.svcReq).Return(tc.svcRes, tc.svcErr)
@@ -2394,7 +2394,7 @@ func TestListUserGroups(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = mgauthn.Session{DomainUserID: validID, UserID: validID, DomainID: domainID}
tc.session = mgauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("ListGroups", mock.Anything, tc.session, "users", tc.userID, tc.svcReq).Return(tc.svcRes, tc.svcErr)
+126 -126
View File
@@ -206,9 +206,9 @@ func (_m *SDK) BootstrapSecure(externalID string, externalKey string, cryptoKey
return r0, r1
}
// Bootstraps provides a mock function with given fields: pm, token
func (_m *SDK) Bootstraps(pm sdk.PageMetadata, token string) (sdk.BootstrapPage, errors.SDKError) {
ret := _m.Called(pm, token)
// Bootstraps provides a mock function with given fields: pm, domainID, token
func (_m *SDK) Bootstraps(pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError) {
ret := _m.Called(pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for Bootstraps")
@@ -216,17 +216,17 @@ func (_m *SDK) Bootstraps(pm sdk.PageMetadata, token string) (sdk.BootstrapPage,
var r0 sdk.BootstrapPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.BootstrapPage, errors.SDKError)); ok {
return rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) (sdk.BootstrapPage, errors.SDKError)); ok {
return rf(pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.BootstrapPage); ok {
r0 = rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) sdk.BootstrapPage); ok {
r0 = rf(pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.BootstrapPage)
}
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(pm, token)
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -296,9 +296,9 @@ func (_m *SDK) ChannelPermissions(id string, domainID string, token string) (sdk
return r0, r1
}
// Channels provides a mock function with given fields: pm, token
func (_m *SDK) Channels(pm sdk.PageMetadata, token string) (sdk.ChannelsPage, errors.SDKError) {
ret := _m.Called(pm, token)
// Channels provides a mock function with given fields: pm, domainID, token
func (_m *SDK) Channels(pm sdk.PageMetadata, domainID string, token string) (sdk.ChannelsPage, errors.SDKError) {
ret := _m.Called(pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for Channels")
@@ -306,17 +306,17 @@ func (_m *SDK) Channels(pm sdk.PageMetadata, token string) (sdk.ChannelsPage, er
var r0 sdk.ChannelsPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.ChannelsPage, errors.SDKError)); ok {
return rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) (sdk.ChannelsPage, errors.SDKError)); ok {
return rf(pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.ChannelsPage); ok {
r0 = rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) sdk.ChannelsPage); ok {
r0 = rf(pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.ChannelsPage)
}
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(pm, token)
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -356,9 +356,9 @@ func (_m *SDK) ChannelsByThing(thingID string, pm sdk.PageMetadata, domainID str
return r0, r1
}
// Children provides a mock function with given fields: id, pm, token
func (_m *SDK) Children(id string, pm sdk.PageMetadata, token string) (sdk.GroupsPage, errors.SDKError) {
ret := _m.Called(id, pm, token)
// Children provides a mock function with given fields: id, pm, domainID, token
func (_m *SDK) Children(id string, pm sdk.PageMetadata, domainID string, token string) (sdk.GroupsPage, errors.SDKError) {
ret := _m.Called(id, pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for Children")
@@ -366,17 +366,17 @@ func (_m *SDK) Children(id string, pm sdk.PageMetadata, token string) (sdk.Group
var r0 sdk.GroupsPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) (sdk.GroupsPage, errors.SDKError)); ok {
return rf(id, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) (sdk.GroupsPage, errors.SDKError)); ok {
return rf(id, pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) sdk.GroupsPage); ok {
r0 = rf(id, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) sdk.GroupsPage); ok {
r0 = rf(id, pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.GroupsPage)
}
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(id, pm, token)
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(id, pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -1256,9 +1256,9 @@ func (_m *SDK) GroupPermissions(id string, domainID string, token string) (sdk.G
return r0, r1
}
// Groups provides a mock function with given fields: pm, token
func (_m *SDK) Groups(pm sdk.PageMetadata, token string) (sdk.GroupsPage, errors.SDKError) {
ret := _m.Called(pm, token)
// Groups provides a mock function with given fields: pm, domainID, token
func (_m *SDK) Groups(pm sdk.PageMetadata, domainID string, token string) (sdk.GroupsPage, errors.SDKError) {
ret := _m.Called(pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for Groups")
@@ -1266,17 +1266,17 @@ func (_m *SDK) Groups(pm sdk.PageMetadata, token string) (sdk.GroupsPage, errors
var r0 sdk.GroupsPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.GroupsPage, errors.SDKError)); ok {
return rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) (sdk.GroupsPage, errors.SDKError)); ok {
return rf(pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.GroupsPage); ok {
r0 = rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) sdk.GroupsPage); ok {
r0 = rf(pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.GroupsPage)
}
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(pm, token)
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -1344,9 +1344,9 @@ func (_m *SDK) Invitation(userID string, domainID string, token string) (sdk.Inv
return r0, r1
}
// Invitations provides a mock function with given fields: pm, token
func (_m *SDK) Invitations(pm sdk.PageMetadata, token string) (sdk.InvitationPage, error) {
ret := _m.Called(pm, token)
// Invitations provides a mock function with given fields: pm, domainID, token
func (_m *SDK) Invitations(pm sdk.PageMetadata, domainID string, token string) (sdk.InvitationPage, error) {
ret := _m.Called(pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for Invitations")
@@ -1354,17 +1354,17 @@ func (_m *SDK) Invitations(pm sdk.PageMetadata, token string) (sdk.InvitationPag
var r0 sdk.InvitationPage
var r1 error
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.InvitationPage, error)); ok {
return rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) (sdk.InvitationPage, error)); ok {
return rf(pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.InvitationPage); ok {
r0 = rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) sdk.InvitationPage); ok {
r0 = rf(pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.InvitationPage)
}
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) error); ok {
r1 = rf(pm, token)
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string, string) error); ok {
r1 = rf(pm, domainID, token)
} else {
r1 = ret.Error(1)
}
@@ -1430,9 +1430,9 @@ func (_m *SDK) Journal(entityType string, entityID string, pm sdk.PageMetadata,
return r0, r1
}
// ListChannelUserGroups provides a mock function with given fields: channelID, pm, token
func (_m *SDK) ListChannelUserGroups(channelID string, pm sdk.PageMetadata, token string) (sdk.GroupsPage, errors.SDKError) {
ret := _m.Called(channelID, pm, token)
// ListChannelUserGroups provides a mock function with given fields: channelID, pm, domainID, token
func (_m *SDK) ListChannelUserGroups(channelID string, pm sdk.PageMetadata, domainID string, token string) (sdk.GroupsPage, errors.SDKError) {
ret := _m.Called(channelID, pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for ListChannelUserGroups")
@@ -1440,17 +1440,17 @@ func (_m *SDK) ListChannelUserGroups(channelID string, pm sdk.PageMetadata, toke
var r0 sdk.GroupsPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) (sdk.GroupsPage, errors.SDKError)); ok {
return rf(channelID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) (sdk.GroupsPage, errors.SDKError)); ok {
return rf(channelID, pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) sdk.GroupsPage); ok {
r0 = rf(channelID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) sdk.GroupsPage); ok {
r0 = rf(channelID, pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.GroupsPage)
}
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(channelID, pm, token)
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(channelID, pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -1460,9 +1460,9 @@ func (_m *SDK) ListChannelUserGroups(channelID string, pm sdk.PageMetadata, toke
return r0, r1
}
// ListChannelUsers provides a mock function with given fields: channelID, pm, token
func (_m *SDK) ListChannelUsers(channelID string, pm sdk.PageMetadata, token string) (sdk.UsersPage, errors.SDKError) {
ret := _m.Called(channelID, pm, token)
// ListChannelUsers provides a mock function with given fields: channelID, pm, domainID, token
func (_m *SDK) ListChannelUsers(channelID string, pm sdk.PageMetadata, domainID string, token string) (sdk.UsersPage, errors.SDKError) {
ret := _m.Called(channelID, pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for ListChannelUsers")
@@ -1470,17 +1470,17 @@ func (_m *SDK) ListChannelUsers(channelID string, pm sdk.PageMetadata, token str
var r0 sdk.UsersPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) (sdk.UsersPage, errors.SDKError)); ok {
return rf(channelID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) (sdk.UsersPage, errors.SDKError)); ok {
return rf(channelID, pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) sdk.UsersPage); ok {
r0 = rf(channelID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) sdk.UsersPage); ok {
r0 = rf(channelID, pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.UsersPage)
}
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(channelID, pm, token)
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(channelID, pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -1520,9 +1520,9 @@ func (_m *SDK) ListDomainUsers(domainID string, pm sdk.PageMetadata, token strin
return r0, r1
}
// ListGroupChannels provides a mock function with given fields: groupID, pm, token
func (_m *SDK) ListGroupChannels(groupID string, pm sdk.PageMetadata, token string) (sdk.ChannelsPage, errors.SDKError) {
ret := _m.Called(groupID, pm, token)
// ListGroupChannels provides a mock function with given fields: groupID, pm, domainID, token
func (_m *SDK) ListGroupChannels(groupID string, pm sdk.PageMetadata, domainID string, token string) (sdk.ChannelsPage, errors.SDKError) {
ret := _m.Called(groupID, pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for ListGroupChannels")
@@ -1530,17 +1530,17 @@ func (_m *SDK) ListGroupChannels(groupID string, pm sdk.PageMetadata, token stri
var r0 sdk.ChannelsPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) (sdk.ChannelsPage, errors.SDKError)); ok {
return rf(groupID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) (sdk.ChannelsPage, errors.SDKError)); ok {
return rf(groupID, pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) sdk.ChannelsPage); ok {
r0 = rf(groupID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) sdk.ChannelsPage); ok {
r0 = rf(groupID, pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.ChannelsPage)
}
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(groupID, pm, token)
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(groupID, pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -1550,9 +1550,9 @@ func (_m *SDK) ListGroupChannels(groupID string, pm sdk.PageMetadata, token stri
return r0, r1
}
// ListGroupUsers provides a mock function with given fields: groupID, pm, token
func (_m *SDK) ListGroupUsers(groupID string, pm sdk.PageMetadata, token string) (sdk.UsersPage, errors.SDKError) {
ret := _m.Called(groupID, pm, token)
// ListGroupUsers provides a mock function with given fields: groupID, pm, domainID, token
func (_m *SDK) ListGroupUsers(groupID string, pm sdk.PageMetadata, domainID string, token string) (sdk.UsersPage, errors.SDKError) {
ret := _m.Called(groupID, pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for ListGroupUsers")
@@ -1560,17 +1560,17 @@ func (_m *SDK) ListGroupUsers(groupID string, pm sdk.PageMetadata, token string)
var r0 sdk.UsersPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) (sdk.UsersPage, errors.SDKError)); ok {
return rf(groupID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) (sdk.UsersPage, errors.SDKError)); ok {
return rf(groupID, pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) sdk.UsersPage); ok {
r0 = rf(groupID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) sdk.UsersPage); ok {
r0 = rf(groupID, pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.UsersPage)
}
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(groupID, pm, token)
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(groupID, pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -1610,9 +1610,9 @@ func (_m *SDK) ListSubscriptions(pm sdk.PageMetadata, token string) (sdk.Subscri
return r0, r1
}
// ListThingUsers provides a mock function with given fields: thingID, pm, token
func (_m *SDK) ListThingUsers(thingID string, pm sdk.PageMetadata, token string) (sdk.UsersPage, errors.SDKError) {
ret := _m.Called(thingID, pm, token)
// ListThingUsers provides a mock function with given fields: thingID, pm, domainID, token
func (_m *SDK) ListThingUsers(thingID string, pm sdk.PageMetadata, domainID string, token string) (sdk.UsersPage, errors.SDKError) {
ret := _m.Called(thingID, pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for ListThingUsers")
@@ -1620,17 +1620,17 @@ func (_m *SDK) ListThingUsers(thingID string, pm sdk.PageMetadata, token string)
var r0 sdk.UsersPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) (sdk.UsersPage, errors.SDKError)); ok {
return rf(thingID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) (sdk.UsersPage, errors.SDKError)); ok {
return rf(thingID, pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) sdk.UsersPage); ok {
r0 = rf(thingID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) sdk.UsersPage); ok {
r0 = rf(thingID, pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.UsersPage)
}
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(thingID, pm, token)
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(thingID, pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -1790,9 +1790,9 @@ func (_m *SDK) Members(groupID string, meta sdk.PageMetadata, token string) (sdk
return r0, r1
}
// Parents provides a mock function with given fields: id, pm, token
func (_m *SDK) Parents(id string, pm sdk.PageMetadata, token string) (sdk.GroupsPage, errors.SDKError) {
ret := _m.Called(id, pm, token)
// Parents provides a mock function with given fields: id, pm, domainID, token
func (_m *SDK) Parents(id string, pm sdk.PageMetadata, domainID string, token string) (sdk.GroupsPage, errors.SDKError) {
ret := _m.Called(id, pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for Parents")
@@ -1800,17 +1800,17 @@ func (_m *SDK) Parents(id string, pm sdk.PageMetadata, token string) (sdk.Groups
var r0 sdk.GroupsPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) (sdk.GroupsPage, errors.SDKError)); ok {
return rf(id, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) (sdk.GroupsPage, errors.SDKError)); ok {
return rf(id, pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) sdk.GroupsPage); ok {
r0 = rf(id, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) sdk.GroupsPage); ok {
r0 = rf(id, pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.GroupsPage)
}
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(id, pm, token)
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(id, pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -1850,9 +1850,9 @@ func (_m *SDK) ReadMessages(pm sdk.MessagePageMetadata, chanID string, token str
return r0, r1
}
// RefreshToken provides a mock function with given fields: lt, token
func (_m *SDK) RefreshToken(lt sdk.Login, token string) (sdk.Token, errors.SDKError) {
ret := _m.Called(lt, token)
// RefreshToken provides a mock function with given fields: token
func (_m *SDK) RefreshToken(token string) (sdk.Token, errors.SDKError) {
ret := _m.Called(token)
if len(ret) == 0 {
panic("no return value specified for RefreshToken")
@@ -1860,17 +1860,17 @@ func (_m *SDK) RefreshToken(lt sdk.Login, token string) (sdk.Token, errors.SDKEr
var r0 sdk.Token
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(sdk.Login, string) (sdk.Token, errors.SDKError)); ok {
return rf(lt, token)
if rf, ok := ret.Get(0).(func(string) (sdk.Token, errors.SDKError)); ok {
return rf(token)
}
if rf, ok := ret.Get(0).(func(sdk.Login, string) sdk.Token); ok {
r0 = rf(lt, token)
if rf, ok := ret.Get(0).(func(string) sdk.Token); ok {
r0 = rf(token)
} else {
r0 = ret.Get(0).(sdk.Token)
}
if rf, ok := ret.Get(1).(func(sdk.Login, string) errors.SDKError); ok {
r1 = rf(lt, token)
if rf, ok := ret.Get(1).(func(string) errors.SDKError); ok {
r1 = rf(token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -2236,9 +2236,9 @@ func (_m *SDK) ThingPermissions(id string, domainID string, token string) (sdk.T
return r0, r1
}
// Things provides a mock function with given fields: pm, token
func (_m *SDK) Things(pm sdk.PageMetadata, token string) (sdk.ThingsPage, errors.SDKError) {
ret := _m.Called(pm, token)
// Things provides a mock function with given fields: pm, domainID, token
func (_m *SDK) Things(pm sdk.PageMetadata, domainID string, token string) (sdk.ThingsPage, errors.SDKError) {
ret := _m.Called(pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for Things")
@@ -2246,17 +2246,17 @@ func (_m *SDK) Things(pm sdk.PageMetadata, token string) (sdk.ThingsPage, errors
var r0 sdk.ThingsPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.ThingsPage, errors.SDKError)); ok {
return rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) (sdk.ThingsPage, errors.SDKError)); ok {
return rf(pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.ThingsPage); ok {
r0 = rf(pm, token)
if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) sdk.ThingsPage); ok {
r0 = rf(pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.ThingsPage)
}
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(pm, token)
if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
@@ -2266,9 +2266,9 @@ func (_m *SDK) Things(pm sdk.PageMetadata, token string) (sdk.ThingsPage, errors
return r0, r1
}
// ThingsByChannel provides a mock function with given fields: chanID, pm, token
func (_m *SDK) ThingsByChannel(chanID string, pm sdk.PageMetadata, token string) (sdk.ThingsPage, errors.SDKError) {
ret := _m.Called(chanID, pm, token)
// ThingsByChannel provides a mock function with given fields: chanID, pm, domainID, token
func (_m *SDK) ThingsByChannel(chanID string, pm sdk.PageMetadata, domainID string, token string) (sdk.ThingsPage, errors.SDKError) {
ret := _m.Called(chanID, pm, domainID, token)
if len(ret) == 0 {
panic("no return value specified for ThingsByChannel")
@@ -2276,17 +2276,17 @@ func (_m *SDK) ThingsByChannel(chanID string, pm sdk.PageMetadata, token string)
var r0 sdk.ThingsPage
var r1 errors.SDKError
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) (sdk.ThingsPage, errors.SDKError)); ok {
return rf(chanID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) (sdk.ThingsPage, errors.SDKError)); ok {
return rf(chanID, pm, domainID, token)
}
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string) sdk.ThingsPage); ok {
r0 = rf(chanID, pm, token)
if rf, ok := ret.Get(0).(func(string, sdk.PageMetadata, string, string) sdk.ThingsPage); ok {
r0 = rf(chanID, pm, domainID, token)
} else {
r0 = ret.Get(0).(sdk.ThingsPage)
}
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string) errors.SDKError); ok {
r1 = rf(chanID, pm, token)
if rf, ok := ret.Get(1).(func(string, sdk.PageMetadata, string, string) errors.SDKError); ok {
r1 = rf(chanID, pm, domainID, token)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(errors.SDKError)
-1
View File
@@ -296,7 +296,6 @@ func (ps *provisionService) createTokenIfEmpty(token string) (string, error) {
u := sdk.Login{
Identity: ps.conf.Server.MgUser,
Secret: ps.conf.Server.MgPass,
DomainID: ps.conf.Server.MgDomainID,
}
tkn, err := ps.sdk.CreateToken(u)
if err != nil {
-1
View File
@@ -221,7 +221,6 @@ func TestCert(t *testing.T) {
login := sdk.Login{
Identity: c.config.Server.MgUser,
Secret: c.config.Server.MgPass,
DomainID: c.config.Server.MgDomainID,
}
mgsdk.On("CreateToken", login).Return(sdk.Token{AccessToken: validToken}, c.sdkTokenErr)
cert, key, err := svc.Cert(c.domainID, c.token, c.thingID, c.ttl)
+1 -1
View File
@@ -28,7 +28,7 @@ func groupsHandler(svc groups.Service, authn mgauthn.Authentication, r *chi.Mux,
}
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddlewareDomain(authn))
r.Use(api.AuthenticateMiddleware(authn, true))
r.Route("/{domainID}/channels", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
+1 -1
View File
@@ -26,7 +26,7 @@ func clientsHandler(svc things.Service, r *chi.Mux, authn mgauthn.Authentication
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddlewareDomain(authn))
r.Use(api.AuthenticateMiddleware(authn, true))
r.Route("/{domainID}/things", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
+1
View File
@@ -523,6 +523,7 @@ func deleteClientEndpoint(svc things.Service) endpoint.Endpoint {
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.DeleteClient(ctx, session, req.id); err != nil {
return nil, err
}
File diff suppressed because it is too large Load Diff
+3 -6
View File
@@ -170,7 +170,6 @@ func createUser(s sdk.SDK, conf Config) (string, string, error) {
login = sdk.Login{
Identity: user.Credentials.Identity,
Secret: user.Credentials.Secret,
DomainID: domain.ID,
}
token, err = s.CreateToken(login)
if err != nil {
@@ -333,7 +332,7 @@ func read(s sdk.SDK, conf Config, domainID, token string, users []sdk.User, grou
return fmt.Errorf("failed to get group %w", err)
}
}
gp, err := s.Groups(sdk.PageMetadata{}, token)
gp, err := s.Groups(sdk.PageMetadata{}, domainID, token)
if err != nil {
return fmt.Errorf("failed to get groups %w", err)
}
@@ -345,7 +344,7 @@ func read(s sdk.SDK, conf Config, domainID, token string, users []sdk.User, grou
return fmt.Errorf("failed to get thing %w", err)
}
}
tp, err := s.Things(sdk.PageMetadata{}, token)
tp, err := s.Things(sdk.PageMetadata{}, domainID, token)
if err != nil {
return fmt.Errorf("failed to get things %w", err)
}
@@ -357,9 +356,7 @@ func read(s sdk.SDK, conf Config, domainID, token string, users []sdk.User, grou
return fmt.Errorf("failed to get channel %w", err)
}
}
cp, err := s.Channels(sdk.PageMetadata{
DomainID: domainID,
}, token)
cp, err := s.Channels(sdk.PageMetadata{}, domainID, token)
if err != nil {
return fmt.Errorf("failed to get channels %w", err)
}
-1
View File
@@ -114,7 +114,6 @@ func Provision(conf Config) error {
token, err = s.CreateToken(sdk.Login{
Identity: user.Credentials.Identity,
Secret: user.Credentials.Secret,
DomainID: domain.ID,
})
if err != nil {
return fmt.Errorf("unable to login user: %w", err)
+7 -6
View File
@@ -46,7 +46,7 @@ func clientsHandler(svc users.Service, authn mgauthn.Authentication, tokenClient
opts...,
), "register_client").ServeHTTP)
default:
r.With(api.AuthenticateMiddleware(authn)).Post("/", otelhttp.NewHandler(kithttp.NewServer(
r.With(api.AuthenticateMiddleware(authn, false)).Post("/", otelhttp.NewHandler(kithttp.NewServer(
registrationEndpoint(svc, selfRegister),
decodeCreateClientReq,
api.EncodeResponse,
@@ -55,7 +55,7 @@ func clientsHandler(svc users.Service, authn mgauthn.Authentication, tokenClient
}
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn))
r.Use(api.AuthenticateMiddleware(authn, false))
r.Get("/profile", otelhttp.NewHandler(kithttp.NewServer(
viewProfileEndpoint(svc),
@@ -158,13 +158,17 @@ func clientsHandler(svc users.Service, authn mgauthn.Authentication, tokenClient
})
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn))
r.Use(api.AuthenticateMiddleware(authn, false))
r.Put("/password/reset", otelhttp.NewHandler(kithttp.NewServer(
passwordResetEndpoint(svc),
decodePasswordReset,
api.EncodeResponse,
opts...,
), "password_reset").ServeHTTP)
})
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
// Ideal location: users service, groups endpoint.
// Reason for placing here :
@@ -465,9 +469,6 @@ func decodeRefreshToken(_ context.Context, r *http.Request) (interface{}, error)
}
req := tokenReq{RefreshToken: apiutil.ExtractBearerToken(r)}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
}
return req, nil
}
+18 -18
View File
@@ -2485,7 +2485,7 @@ func TestListUsersByUserGroupId(t *testing.T) {
token: tc.token,
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("ListMembers", mock.Anything, mgauthn.Session{UserID: validID, DomainID: domainID}, mock.Anything, mock.Anything, mock.Anything).Return(
svcCall := svc.On("ListMembers", mock.Anything, mgauthn.Session{UserID: validID, DomainID: validID, DomainUserID: validID + "_" + validID}, mock.Anything, mock.Anything, mock.Anything).Return(
mgclients.MembersPage{
Page: tc.listUsersResponse.Page,
Members: tc.listUsersResponse.Clients,
@@ -2823,7 +2823,7 @@ func TestListUsersByChannelID(t *testing.T) {
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("ListMembers", mock.Anything, mgauthn.Session{UserID: validID, DomainID: domainID}, mock.Anything, mock.Anything, mock.Anything).Return(
svcCall := svc.On("ListMembers", mock.Anything, mgauthn.Session{UserID: validID, DomainID: validID, DomainUserID: validID + "_" + validID}, mock.Anything, mock.Anything, mock.Anything).Return(
mgclients.MembersPage{
Page: tc.listUsersResponse.Page,
Members: tc.listUsersResponse.Clients,
@@ -3167,7 +3167,7 @@ func TestListUsersByDomainID(t *testing.T) {
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("ListMembers", mock.Anything, mgauthn.Session{UserID: validID, DomainID: domainID}, mock.Anything, mock.Anything, mock.Anything).Return(
svcCall := svc.On("ListMembers", mock.Anything, mgauthn.Session{UserID: validID, DomainID: validID, DomainUserID: validID + "_" + validID}, mock.Anything, mock.Anything, mock.Anything).Return(
mgclients.MembersPage{
Page: tc.listUsersResponse.Page,
Members: tc.listUsersResponse.Clients,
@@ -3481,7 +3481,7 @@ func TestListUsersByThingID(t *testing.T) {
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("ListMembers", mock.Anything, mgauthn.Session{UserID: validID, DomainID: domainID}, mock.Anything, mock.Anything, mock.Anything).Return(
svcCall := svc.On("ListMembers", mock.Anything, mgauthn.Session{UserID: validID, DomainID: validID, DomainUserID: validID + "_" + validID}, mock.Anything, mock.Anything, mock.Anything).Return(
mgclients.MembersPage{
Page: tc.listUsersResponse.Page,
Members: tc.listUsersResponse.Clients,
@@ -3515,7 +3515,7 @@ func TestAssignUsers(t *testing.T) {
desc: "assign users to a group successfully",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: groupReqBody{
Relation: "member",
@@ -3553,7 +3553,7 @@ func TestAssignUsers(t *testing.T) {
desc: "assign users to a group with empty relation",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: groupReqBody{
Relation: "",
@@ -3566,7 +3566,7 @@ func TestAssignUsers(t *testing.T) {
desc: "assign users to a group with empty user ids",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: groupReqBody{
Relation: "member",
@@ -3579,7 +3579,7 @@ func TestAssignUsers(t *testing.T) {
desc: "assign users to a group with invalid request body",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: map[string]interface{}{
"relation": make(chan int),
@@ -3629,7 +3629,7 @@ func TestUnassignUsers(t *testing.T) {
desc: "unassign users from a group successfully",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: groupReqBody{
Relation: "member",
@@ -3667,7 +3667,7 @@ func TestUnassignUsers(t *testing.T) {
desc: "unassign users from a group with empty relation",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: groupReqBody{
Relation: "",
@@ -3680,7 +3680,7 @@ func TestUnassignUsers(t *testing.T) {
desc: "unassign users from a group with empty user ids",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: groupReqBody{
Relation: "member",
@@ -3693,7 +3693,7 @@ func TestUnassignUsers(t *testing.T) {
desc: "unassign users from a group with invalid request body",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: map[string]interface{}{
"relation": make(chan int),
@@ -3743,7 +3743,7 @@ func TestAssignGroups(t *testing.T) {
desc: "assign groups to a parent group successfully",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: groupReqBody{
GroupIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
@@ -3778,7 +3778,7 @@ func TestAssignGroups(t *testing.T) {
desc: "assign groups to a parent group with empty parent group id",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: "",
reqBody: groupReqBody{
GroupIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
@@ -3790,7 +3790,7 @@ func TestAssignGroups(t *testing.T) {
desc: "assign groups to a parent group with empty group ids",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: groupReqBody{
GroupIDs: []string{},
@@ -3802,7 +3802,7 @@ func TestAssignGroups(t *testing.T) {
desc: "assign groups to a parent group with invalid request body",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: map[string]interface{}{
"group_ids": make(chan int),
@@ -3852,7 +3852,7 @@ func TestUnassignGroups(t *testing.T) {
desc: "unassign groups from a parent group successfully",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: validID,
reqBody: groupReqBody{
GroupIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
@@ -3887,7 +3887,7 @@ func TestUnassignGroups(t *testing.T) {
desc: "unassign groups from a parent group with empty group id",
domainID: domainID,
token: validToken,
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID},
authnRes: mgauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
groupID: "",
reqBody: groupReqBody{
GroupIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
+2 -2
View File
@@ -424,7 +424,7 @@ func issueTokenEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
token, err := svc.IssueToken(ctx, req.Identity, req.Secret, req.DomainID)
token, err := svc.IssueToken(ctx, req.Identity, req.Secret)
if err != nil {
return nil, err
}
@@ -449,7 +449,7 @@ func refreshTokenEndpoint(svc users.Service) endpoint.Endpoint {
return nil, svcerr.ErrAuthorization
}
token, err := svc.RefreshToken(ctx, session, req.RefreshToken, req.DomainID)
token, err := svc.RefreshToken(ctx, session, req.RefreshToken)
if err != nil {
return nil, err
}
+1 -1
View File
@@ -31,7 +31,7 @@ func groupsHandler(svc groups.Service, authn mgauthn.Authentication, r *chi.Mux,
}
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddlewareDomain(authn))
r.Use(api.AuthenticateMiddleware(authn, true))
r.Route("/{domainID}/groups", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
-2
View File
@@ -187,7 +187,6 @@ func (req changeClientStatusReq) validate() error {
type loginClientReq struct {
Identity string `json:"identity,omitempty"`
Secret string `json:"secret,omitempty"`
DomainID string `json:"domain_id,omitempty"`
}
func (req loginClientReq) validate() error {
@@ -203,7 +202,6 @@ func (req loginClientReq) validate() error {
type tokenReq struct {
RefreshToken string `json:"refresh_token,omitempty"`
DomainID string `json:"domain_id,omitempty"`
}
func (req tokenReq) validate() error {
+2 -2
View File
@@ -74,12 +74,12 @@ type Service interface {
Identify(ctx context.Context, session authn.Session) (string, error)
// IssueToken issues a new access and refresh token.
IssueToken(ctx context.Context, identity, secret, domainID string) (*magistrala.Token, error)
IssueToken(ctx context.Context, identity, secret string) (*magistrala.Token, error)
// RefreshToken refreshes expired access tokens.
// After an access token expires, the refresh token is used to get
// a new pair of access and refresh tokens.
RefreshToken(ctx context.Context, session authn.Session, refreshToken, domainID string) (*magistrala.Token, error)
RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*magistrala.Token, error)
// OAuthCallback handles the callback from any supported OAuth provider.
// It processes the OAuth tokens and either signs in or signs up the user based on the provided state.
+1 -6
View File
@@ -360,25 +360,20 @@ func (grte generateResetTokenEvent) Encode() (map[string]interface{}, error) {
type issueTokenEvent struct {
identity string
domainID string
}
func (ite issueTokenEvent) Encode() (map[string]interface{}, error) {
return map[string]interface{}{
"operation": issueToken,
"identity": ite.identity,
"domain_id": ite.domainID,
}, nil
}
type refreshTokenEvent struct {
domainID string
}
type refreshTokenEvent struct{}
func (rte refreshTokenEvent) Encode() (map[string]interface{}, error) {
return map[string]interface{}{
"operation": refreshToken,
"domain_id": rte.domainID,
}, nil
}
+5 -6
View File
@@ -257,15 +257,14 @@ func (es *eventStore) GenerateResetToken(ctx context.Context, email, host string
return es.Publish(ctx, event)
}
func (es *eventStore) IssueToken(ctx context.Context, identity, secret, domainID string) (*magistrala.Token, error) {
token, err := es.svc.IssueToken(ctx, identity, secret, domainID)
func (es *eventStore) IssueToken(ctx context.Context, identity, secret string) (*magistrala.Token, error) {
token, err := es.svc.IssueToken(ctx, identity, secret)
if err != nil {
return token, err
}
event := issueTokenEvent{
identity: identity,
domainID: domainID,
}
if err := es.Publish(ctx, event); err != nil {
@@ -275,13 +274,13 @@ func (es *eventStore) IssueToken(ctx context.Context, identity, secret, domainID
return token, nil
}
func (es *eventStore) RefreshToken(ctx context.Context, session authn.Session, refreshToken, domainID string) (*magistrala.Token, error) {
token, err := es.svc.RefreshToken(ctx, session, refreshToken, domainID)
func (es *eventStore) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*magistrala.Token, error) {
token, err := es.svc.RefreshToken(ctx, session, refreshToken)
if err != nil {
return token, err
}
event := refreshTokenEvent{domainID: domainID}
event := refreshTokenEvent{}
if err := es.Publish(ctx, event); err != nil {
return token, err
+7 -7
View File
@@ -70,15 +70,15 @@ func (am *authorizationMiddleware) ListMembers(ctx context.Context, session auth
}
switch objectKind {
case policies.GroupsKind:
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.GroupType, objectID); err != nil {
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.UserID, mgauth.SwitchToPermission(pm.Permission), policies.GroupType, objectID); err != nil {
return clients.MembersPage{}, err
}
case policies.DomainsKind:
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.DomainType, objectID); err != nil {
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.UserID, mgauth.SwitchToPermission(pm.Permission), policies.DomainType, objectID); err != nil {
return clients.MembersPage{}, err
}
case policies.ThingsKind:
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.ThingType, objectID); err != nil {
if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.UserID, mgauth.SwitchToPermission(pm.Permission), policies.ThingType, objectID); err != nil {
return clients.MembersPage{}, err
}
default:
@@ -171,12 +171,12 @@ func (am *authorizationMiddleware) Identify(ctx context.Context, session authn.S
return am.svc.Identify(ctx, session)
}
func (am *authorizationMiddleware) IssueToken(ctx context.Context, identity, secret, domainID string) (*magistrala.Token, error) {
return am.svc.IssueToken(ctx, identity, secret, domainID)
func (am *authorizationMiddleware) IssueToken(ctx context.Context, identity, secret string) (*magistrala.Token, error) {
return am.svc.IssueToken(ctx, identity, secret)
}
func (am *authorizationMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken, domainID string) (*magistrala.Token, error) {
return am.svc.RefreshToken(ctx, session, refreshToken, domainID)
func (am *authorizationMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*magistrala.Token, error) {
return am.svc.RefreshToken(ctx, session, refreshToken)
}
func (am *authorizationMiddleware) OAuthCallback(ctx context.Context, client clients.Client) (clients.Client, error) {
+4 -6
View File
@@ -49,11 +49,10 @@ func (lm *loggingMiddleware) RegisterClient(ctx context.Context, session authn.S
// IssueToken logs the issue_token request. It logs the client identity type and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) IssueToken(ctx context.Context, identity, secret, domainID string) (t *magistrala.Token, err error) {
func (lm *loggingMiddleware) IssueToken(ctx context.Context, identity, secret string) (t *magistrala.Token, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", domainID),
}
if t.AccessType != "" {
args = append(args, slog.String("access_type", t.AccessType))
@@ -65,16 +64,15 @@ func (lm *loggingMiddleware) IssueToken(ctx context.Context, identity, secret, d
}
lm.logger.Info("Issue token completed successfully", args...)
}(time.Now())
return lm.svc.IssueToken(ctx, identity, secret, domainID)
return lm.svc.IssueToken(ctx, identity, secret)
}
// RefreshToken logs the refresh_token request. It logs the refreshtoken, token type and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken, domainID string) (t *magistrala.Token, err error) {
func (lm *loggingMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (t *magistrala.Token, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", domainID),
}
if t.AccessType != "" {
args = append(args, slog.String("access_type", t.AccessType))
@@ -86,7 +84,7 @@ func (lm *loggingMiddleware) RefreshToken(ctx context.Context, session authn.Ses
}
lm.logger.Info("Refresh token completed successfully", args...)
}(time.Now())
return lm.svc.RefreshToken(ctx, session, refreshToken, domainID)
return lm.svc.RefreshToken(ctx, session, refreshToken)
}
// ViewClient logs the view_client request. It logs the client id and the time it took to complete the request.
+4 -4
View File
@@ -41,21 +41,21 @@ func (ms *metricsMiddleware) RegisterClient(ctx context.Context, session authn.S
}
// IssueToken instruments IssueToken method with metrics.
func (ms *metricsMiddleware) IssueToken(ctx context.Context, identity, secret, domainID string) (*magistrala.Token, error) {
func (ms *metricsMiddleware) IssueToken(ctx context.Context, identity, secret string) (*magistrala.Token, error) {
defer func(begin time.Time) {
ms.counter.With("method", "issue_token").Add(1)
ms.latency.With("method", "issue_token").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.IssueToken(ctx, identity, secret, domainID)
return ms.svc.IssueToken(ctx, identity, secret)
}
// RefreshToken instruments RefreshToken method with metrics.
func (ms *metricsMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken, domainID string) (token *magistrala.Token, err error) {
func (ms *metricsMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (token *magistrala.Token, err error) {
defer func(begin time.Time) {
ms.counter.With("method", "refresh_token").Add(1)
ms.latency.With("method", "refresh_token").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RefreshToken(ctx, session, refreshToken, domainID)
return ms.svc.RefreshToken(ctx, session, refreshToken)
}
// ViewClient instruments ViewClient method with metrics.
+18 -18
View File
@@ -140,9 +140,9 @@ func (_m *Service) Identify(ctx context.Context, session authn.Session) (string,
return r0, r1
}
// IssueToken provides a mock function with given fields: ctx, identity, secret, domainID
func (_m *Service) IssueToken(ctx context.Context, identity string, secret string, domainID string) (*magistrala.Token, error) {
ret := _m.Called(ctx, identity, secret, domainID)
// IssueToken provides a mock function with given fields: ctx, identity, secret
func (_m *Service) IssueToken(ctx context.Context, identity string, secret string) (*magistrala.Token, error) {
ret := _m.Called(ctx, identity, secret)
if len(ret) == 0 {
panic("no return value specified for IssueToken")
@@ -150,19 +150,19 @@ func (_m *Service) IssueToken(ctx context.Context, identity string, secret strin
var r0 *magistrala.Token
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*magistrala.Token, error)); ok {
return rf(ctx, identity, secret, domainID)
if rf, ok := ret.Get(0).(func(context.Context, string, string) (*magistrala.Token, error)); ok {
return rf(ctx, identity, secret)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *magistrala.Token); ok {
r0 = rf(ctx, identity, secret, domainID)
if rf, ok := ret.Get(0).(func(context.Context, string, string) *magistrala.Token); ok {
r0 = rf(ctx, identity, secret)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*magistrala.Token)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, identity, secret, domainID)
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, identity, secret)
} else {
r1 = ret.Error(1)
}
@@ -272,9 +272,9 @@ func (_m *Service) OAuthCallback(ctx context.Context, client clients.Client) (cl
return r0, r1
}
// RefreshToken provides a mock function with given fields: ctx, session, refreshToken, domainID
func (_m *Service) RefreshToken(ctx context.Context, session authn.Session, refreshToken string, domainID string) (*magistrala.Token, error) {
ret := _m.Called(ctx, session, refreshToken, domainID)
// RefreshToken provides a mock function with given fields: ctx, session, refreshToken
func (_m *Service) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*magistrala.Token, error) {
ret := _m.Called(ctx, session, refreshToken)
if len(ret) == 0 {
panic("no return value specified for RefreshToken")
@@ -282,19 +282,19 @@ func (_m *Service) RefreshToken(ctx context.Context, session authn.Session, refr
var r0 *magistrala.Token
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) (*magistrala.Token, error)); ok {
return rf(ctx, session, refreshToken, domainID)
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) (*magistrala.Token, error)); ok {
return rf(ctx, session, refreshToken)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) *magistrala.Token); ok {
r0 = rf(ctx, session, refreshToken, domainID)
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) *magistrala.Token); ok {
r0 = rf(ctx, session, refreshToken)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*magistrala.Token)
}
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, string) error); ok {
r1 = rf(ctx, session, refreshToken, domainID)
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok {
r1 = rf(ctx, session, refreshToken)
} else {
r1 = ret.Error(1)
}
+5 -15
View File
@@ -93,7 +93,7 @@ func (svc service) RegisterClient(ctx context.Context, session authn.Session, cl
return client, nil
}
func (svc service) IssueToken(ctx context.Context, identity, secret, domainID string) (*magistrala.Token, error) {
func (svc service) IssueToken(ctx context.Context, identity, secret string) (*magistrala.Token, error) {
dbUser, err := svc.clients.RetrieveByIdentity(ctx, identity)
if err != nil {
return &magistrala.Token{}, errors.Wrap(svcerr.ErrAuthentication, err)
@@ -102,12 +102,7 @@ func (svc service) IssueToken(ctx context.Context, identity, secret, domainID st
return &magistrala.Token{}, errors.Wrap(svcerr.ErrLogin, err)
}
var d string
if domainID != "" {
d = domainID
}
token, err := svc.token.Issue(ctx, &magistrala.IssueReq{UserId: dbUser.ID, DomainId: &d, Type: uint32(mgauth.AccessKey)})
token, err := svc.token.Issue(ctx, &magistrala.IssueReq{UserId: dbUser.ID, Type: uint32(mgauth.AccessKey)})
if err != nil {
return &magistrala.Token{}, errors.Wrap(errIssueToken, err)
}
@@ -115,12 +110,7 @@ func (svc service) IssueToken(ctx context.Context, identity, secret, domainID st
return token, err
}
func (svc service) RefreshToken(ctx context.Context, session authn.Session, refreshToken, domainID string) (*magistrala.Token, error) {
var d string
if domainID != "" {
d = domainID
}
func (svc service) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*magistrala.Token, error) {
dbUser, err := svc.clients.RetrieveByID(ctx, session.UserID)
if err != nil {
return &magistrala.Token{}, errors.Wrap(svcerr.ErrAuthentication, err)
@@ -129,7 +119,7 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr
return &magistrala.Token{}, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser)
}
return svc.token.Refresh(ctx, &magistrala.RefreshReq{RefreshToken: refreshToken, DomainId: &d})
return svc.token.Refresh(ctx, &magistrala.RefreshReq{RefreshToken: refreshToken})
}
func (svc service) ViewClient(ctx context.Context, session authn.Session, id string) (mgclients.Client, error) {
@@ -301,7 +291,7 @@ func (svc service) UpdateClientSecret(ctx context.Context, session authn.Session
if err != nil {
return mgclients.Client{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if _, err := svc.IssueToken(ctx, dbClient.Credentials.Identity, oldSecret, ""); err != nil {
if _, err := svc.IssueToken(ctx, dbClient.Credentials.Identity, oldSecret); err != nil {
return mgclients.Client{}, err
}
newSecret, err = svc.hasher.Hash(newSecret)
+77 -79
View File
@@ -1489,7 +1489,6 @@ func TestIssueToken(t *testing.T) {
cases := []struct {
desc string
domainID string
client mgclients.Client
retrieveByIdentityResponse mgclients.Client
issueResponse *magistrala.Token
@@ -1506,7 +1505,6 @@ func TestIssueToken(t *testing.T) {
},
{
desc: "issue token for non-empty domain id",
domainID: validID,
client: client,
retrieveByIdentityResponse: rClient,
issueResponse: &magistrala.Token{AccessToken: validToken, RefreshToken: &validToken, AccessType: "3"},
@@ -1544,20 +1542,22 @@ func TestIssueToken(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("RetrieveByIdentity", context.Background(), tc.client.Credentials.Identity).Return(tc.retrieveByIdentityResponse, tc.retrieveByIdentityErr)
authCall := auth.On("Issue", context.Background(), &magistrala.IssueReq{UserId: tc.client.ID, DomainId: &tc.domainID, Type: uint32(mgauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
token, err := svc.IssueToken(context.Background(), tc.client.Credentials.Identity, tc.client.Credentials.Secret, tc.domainID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
ok := repoCall.Parent.AssertCalled(t, "RetrieveByIdentity", context.Background(), tc.client.Credentials.Identity)
assert.True(t, ok, fmt.Sprintf("RetrieveByIdentity was not called on %s", tc.desc))
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &magistrala.IssueReq{UserId: tc.client.ID, DomainId: &tc.domainID, Type: uint32(mgauth.AccessKey)})
assert.True(t, ok, fmt.Sprintf("Issue was not called on %s", tc.desc))
}
authCall.Unset()
repoCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveByIdentity", context.Background(), tc.client.Credentials.Identity).Return(tc.retrieveByIdentityResponse, tc.retrieveByIdentityErr)
authCall := auth.On("Issue", context.Background(), &magistrala.IssueReq{UserId: tc.client.ID, Type: uint32(mgauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
token, err := svc.IssueToken(context.Background(), tc.client.Credentials.Identity, tc.client.Credentials.Secret)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
ok := repoCall.Parent.AssertCalled(t, "RetrieveByIdentity", context.Background(), tc.client.Credentials.Identity)
assert.True(t, ok, fmt.Sprintf("RetrieveByIdentity was not called on %s", tc.desc))
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &magistrala.IssueReq{UserId: tc.client.ID, Type: uint32(mgauth.AccessKey)})
assert.True(t, ok, fmt.Sprintf("Issue was not called on %s", tc.desc))
}
authCall.Unset()
repoCall.Unset()
})
}
}
@@ -1570,7 +1570,6 @@ func TestRefreshToken(t *testing.T) {
cases := []struct {
desc string
session authn.Session
domainID string
refreshResp *magistrala.Token
refresErr error
repoResp mgclients.Client
@@ -1580,14 +1579,6 @@ func TestRefreshToken(t *testing.T) {
{
desc: "refresh token with refresh token for an existing client",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
domainID: validID,
refreshResp: &magistrala.Token{AccessToken: validToken, RefreshToken: &validToken, AccessType: "3"},
repoResp: rClient,
err: nil,
},
{
desc: "refresh token with refresh token for empty domain id",
session: authn.Session{UserID: validID},
refreshResp: &magistrala.Token{AccessToken: validToken, RefreshToken: &validToken, AccessType: "3"},
repoResp: rClient,
err: nil,
@@ -1595,23 +1586,20 @@ func TestRefreshToken(t *testing.T) {
{
desc: "refresh token with access token for an existing client",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
domainID: validID,
refreshResp: &magistrala.Token{},
refresErr: svcerr.ErrAuthentication,
repoResp: rClient,
err: svcerr.ErrAuthentication,
},
{
desc: "refresh token with refresh token for a non-existing client",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
domainID: validID,
repoErr: repoerr.ErrNotFound,
err: repoerr.ErrNotFound,
desc: "refresh token with refresh token for a non-existing client",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
repoErr: repoerr.ErrNotFound,
err: repoerr.ErrNotFound,
},
{
desc: "refresh token with refresh token for a disable client",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
domainID: validID,
repoResp: mgclients.Client{Status: mgclients.DisabledStatus},
err: svcerr.ErrAuthentication,
},
@@ -1626,20 +1614,22 @@ func TestRefreshToken(t *testing.T) {
}
for _, tc := range cases {
authCall := authsvc.On("Refresh", context.Background(), &magistrala.RefreshReq{RefreshToken: validToken, DomainId: &tc.domainID}).Return(tc.refreshResp, tc.refresErr)
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
token, err := svc.RefreshToken(context.Background(), tc.session, validToken, tc.domainID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
ok := authCall.Parent.AssertCalled(t, "Refresh", context.Background(), &magistrala.RefreshReq{RefreshToken: validToken, DomainId: &tc.domainID})
assert.True(t, ok, fmt.Sprintf("Refresh was not called on %s", tc.desc))
ok = repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
}
authCall.Unset()
repoCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
authCall := authsvc.On("Refresh", context.Background(), &magistrala.RefreshReq{RefreshToken: validToken}).Return(tc.refreshResp, tc.refresErr)
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
token, err := svc.RefreshToken(context.Background(), tc.session, validToken)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
ok := authCall.Parent.AssertCalled(t, "Refresh", context.Background(), &magistrala.RefreshReq{RefreshToken: validToken})
assert.True(t, ok, fmt.Sprintf("Refresh was not called on %s", tc.desc))
ok = repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
}
authCall.Unset()
repoCall.Unset()
})
}
}
@@ -1689,15 +1679,17 @@ func TestGenerateResetToken(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("RetrieveByIdentity", context.Background(), tc.email).Return(tc.retrieveByIdentityResponse, tc.retrieveByIdentityErr)
authCall := auth.On("Issue", context.Background(), mock.Anything).Return(tc.issueResponse, tc.issueErr)
svcCall := e.On("SendPasswordReset", []string{tc.email}, tc.host, client.Name, validToken).Return(tc.err)
err := svc.GenerateResetToken(context.Background(), tc.email, tc.host)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Parent.AssertCalled(t, "RetrieveByIdentity", context.Background(), tc.email)
repoCall.Unset()
authCall.Unset()
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveByIdentity", context.Background(), tc.email).Return(tc.retrieveByIdentityResponse, tc.retrieveByIdentityErr)
authCall := auth.On("Issue", context.Background(), mock.Anything).Return(tc.issueResponse, tc.issueErr)
svcCall := e.On("SendPasswordReset", []string{tc.email}, tc.host, client.Name, validToken).Return(tc.err)
err := svc.GenerateResetToken(context.Background(), tc.email, tc.host)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Parent.AssertCalled(t, "RetrieveByIdentity", context.Background(), tc.email)
repoCall.Unset()
authCall.Unset()
svcCall.Unset()
})
}
}
@@ -1775,16 +1767,18 @@ func TestResetSecret(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall1 := cRepo.On("UpdateSecret", context.Background(), mock.Anything).Return(tc.updateSecretResponse, tc.updateSecretErr)
err := svc.ResetSecret(context.Background(), tc.session, tc.newSecret)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil {
repoCall1.Parent.AssertCalled(t, "UpdateSecret", context.Background(), mock.Anything)
repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), validID)
}
repoCall1.Unset()
repoCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall1 := cRepo.On("UpdateSecret", context.Background(), mock.Anything).Return(tc.updateSecretResponse, tc.updateSecretErr)
err := svc.ResetSecret(context.Background(), tc.session, tc.newSecret)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil {
repoCall1.Parent.AssertCalled(t, "UpdateSecret", context.Background(), mock.Anything)
repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), validID)
}
repoCall1.Unset()
repoCall.Unset()
})
}
}
@@ -1824,11 +1818,13 @@ func TestViewProfile(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
_, err := svc.ViewProfile(context.Background(), tc.session)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), mock.Anything)
repoCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
_, err := svc.ViewProfile(context.Background(), tc.session)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), mock.Anything)
repoCall.Unset()
})
}
}
@@ -1909,14 +1905,16 @@ func TestOAuthCallback(t *testing.T) {
},
}
for _, tc := range cases {
repoCall := cRepo.On("RetrieveByIdentity", context.Background(), tc.client.Credentials.Identity).Return(tc.retrieveByIdentityResponse, tc.retrieveByIdentityErr)
repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.saveResponse, tc.saveErr)
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr)
_, err := svc.OAuthCallback(context.Background(), tc.client)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Parent.AssertCalled(t, "RetrieveByIdentity", context.Background(), tc.client.Credentials.Identity)
repoCall.Unset()
repoCall1.Unset()
policyCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveByIdentity", context.Background(), tc.client.Credentials.Identity).Return(tc.retrieveByIdentityResponse, tc.retrieveByIdentityErr)
repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.saveResponse, tc.saveErr)
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr)
_, err := svc.OAuthCallback(context.Background(), tc.client)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Parent.AssertCalled(t, "RetrieveByIdentity", context.Background(), tc.client.Credentials.Identity)
repoCall.Unset()
repoCall1.Unset()
policyCall.Unset()
})
}
}
+4 -4
View File
@@ -35,19 +35,19 @@ func (tm *tracingMiddleware) RegisterClient(ctx context.Context, session authn.S
}
// IssueToken traces the "IssueToken" operation of the wrapped clients.Service.
func (tm *tracingMiddleware) IssueToken(ctx context.Context, identity, secret, domainID string) (*magistrala.Token, error) {
func (tm *tracingMiddleware) IssueToken(ctx context.Context, identity, secret string) (*magistrala.Token, error) {
ctx, span := tm.tracer.Start(ctx, "svc_issue_token", trace.WithAttributes(attribute.String("identity", identity)))
defer span.End()
return tm.svc.IssueToken(ctx, identity, secret, domainID)
return tm.svc.IssueToken(ctx, identity, secret)
}
// RefreshToken traces the "RefreshToken" operation of the wrapped clients.Service.
func (tm *tracingMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken, domainID string) (*magistrala.Token, error) {
func (tm *tracingMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*magistrala.Token, error) {
ctx, span := tm.tracer.Start(ctx, "svc_refresh_token", trace.WithAttributes(attribute.String("refresh_token", refreshToken)))
defer span.End()
return tm.svc.RefreshToken(ctx, session, refreshToken, domainID)
return tm.svc.RefreshToken(ctx, session, refreshToken)
}
// ViewClient traces the "ViewClient" operation of the wrapped clients.Service.