mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
SMQ-1744 - Error handling with TypedError created on top existing Error (#3170)
Signed-off-by: Arvindh <arvindh91@gmail.com> Signed-off-by: Felix Gateru <felix.gateru@gmail.com> Co-authored-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
+37
-137
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/groups"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/users"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
@@ -184,146 +183,47 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
|
||||
return
|
||||
}
|
||||
|
||||
var wrapper error
|
||||
if errors.Contains(err, apiutil.ErrValidation) {
|
||||
wrapper, err = errors.Unwrap(err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Contains(err, errors.ErrTryAgain):
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
case errors.Contains(err, errors.ErrEmailAlreadyExists),
|
||||
errors.Contains(err, errors.ErrUsernameNotAvailable),
|
||||
errors.Contains(err, errors.ErrRouteNotAvailable),
|
||||
errors.Contains(err, errors.ErrChannelRouteNotAvailable),
|
||||
errors.Contains(err, errors.ErrDomainRouteNotAvailable),
|
||||
errors.Contains(err, svcerr.ErrExternalAuthProviderCouldNotUpdate):
|
||||
switch retErr := err.(type) {
|
||||
case *errors.RequestError:
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
case errors.Contains(err, svcerr.ErrAuthorization),
|
||||
errors.Contains(err, svcerr.ErrDomainAuthorization),
|
||||
errors.Contains(err, svcerr.ErrUnauthorizedPAT),
|
||||
errors.Contains(err, svcerr.ErrSuperAdminAction):
|
||||
err = unwrap(err)
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
|
||||
case errors.Contains(err, svcerr.ErrAuthentication),
|
||||
errors.Contains(err, apiutil.ErrBearerToken),
|
||||
errors.Contains(err, svcerr.ErrLogin),
|
||||
errors.Contains(err, apiutil.ErrUnsupportedTokenType):
|
||||
err = unwrap(err)
|
||||
if err := json.NewEncoder(w).Encode(retErr); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case *errors.AuthNError:
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
case errors.Contains(err, svcerr.ErrMalformedEntity),
|
||||
errors.Contains(err, apiutil.ErrMalformedPolicy),
|
||||
errors.Contains(err, apiutil.ErrMissingSecret),
|
||||
errors.Contains(err, errors.ErrMalformedEntity),
|
||||
errors.Contains(err, apiutil.ErrMissingID),
|
||||
errors.Contains(err, apiutil.ErrInvalidVerification),
|
||||
errors.Contains(err, apiutil.ErrMissingName),
|
||||
errors.Contains(err, apiutil.ErrMissingEmail),
|
||||
errors.Contains(err, apiutil.ErrInvalidEmail),
|
||||
errors.Contains(err, apiutil.ErrMissingHost),
|
||||
errors.Contains(err, apiutil.ErrInvalidResetPass),
|
||||
errors.Contains(err, apiutil.ErrEmptyList),
|
||||
errors.Contains(err, apiutil.ErrMissingMemberKind),
|
||||
errors.Contains(err, apiutil.ErrMissingMemberType),
|
||||
errors.Contains(err, apiutil.ErrLimitSize),
|
||||
errors.Contains(err, apiutil.ErrBearerKey),
|
||||
errors.Contains(err, svcerr.ErrInvalidStatus),
|
||||
errors.Contains(err, apiutil.ErrNameSize),
|
||||
errors.Contains(err, apiutil.ErrInvalidIDFormat),
|
||||
errors.Contains(err, apiutil.ErrInvalidQueryParams),
|
||||
errors.Contains(err, apiutil.ErrMissingRelation),
|
||||
errors.Contains(err, apiutil.ErrValidation),
|
||||
errors.Contains(err, apiutil.ErrMissingPass),
|
||||
errors.Contains(err, apiutil.ErrMissingConfPass),
|
||||
errors.Contains(err, apiutil.ErrPasswordFormat),
|
||||
errors.Contains(err, svcerr.ErrInvalidRole),
|
||||
errors.Contains(err, svcerr.ErrInvalidPolicy),
|
||||
errors.Contains(err, apiutil.ErrInvitationState),
|
||||
errors.Contains(err, apiutil.ErrInvalidAPIKey),
|
||||
errors.Contains(err, svcerr.ErrViewEntity),
|
||||
errors.Contains(err, apiutil.ErrMissingCertData),
|
||||
errors.Contains(err, apiutil.ErrInvalidContact),
|
||||
errors.Contains(err, apiutil.ErrInvalidTopic),
|
||||
errors.Contains(err, apiutil.ErrInvalidCertData),
|
||||
errors.Contains(err, apiutil.ErrEmptyMessage),
|
||||
errors.Contains(err, apiutil.ErrInvalidLevel),
|
||||
errors.Contains(err, apiutil.ErrInvalidDirection),
|
||||
errors.Contains(err, apiutil.ErrInvalidEntityType),
|
||||
errors.Contains(err, apiutil.ErrMissingEntityType),
|
||||
errors.Contains(err, apiutil.ErrInvalidTimeFormat),
|
||||
errors.Contains(err, svcerr.ErrSearch),
|
||||
errors.Contains(err, apiutil.ErrEmptySearchQuery),
|
||||
errors.Contains(err, apiutil.ErrLenSearchQuery),
|
||||
errors.Contains(err, apiutil.ErrMissingDomainID),
|
||||
errors.Contains(err, apiutil.ErrMissingUserID),
|
||||
errors.Contains(err, apiutil.ErrMissingPATID),
|
||||
errors.Contains(err, apiutil.ErrMissingUsername),
|
||||
errors.Contains(err, apiutil.ErrMissingUsernameEmail),
|
||||
errors.Contains(err, apiutil.ErrMissingFirstName),
|
||||
errors.Contains(err, apiutil.ErrMissingLastName),
|
||||
errors.Contains(err, apiutil.ErrInvalidUsername),
|
||||
errors.Contains(err, apiutil.ErrMissingIdentity),
|
||||
errors.Contains(err, apiutil.ErrInvalidProfilePictureURL),
|
||||
errors.Contains(err, apiutil.ErrSelfParentingNotAllowed),
|
||||
errors.Contains(err, apiutil.ErrMissingChildrenGroupIDs),
|
||||
errors.Contains(err, apiutil.ErrMissingParentGroupID),
|
||||
errors.Contains(err, apiutil.ErrMissingConnectionType),
|
||||
errors.Contains(err, apiutil.ErrMissingRoleName),
|
||||
errors.Contains(err, apiutil.ErrMissingRoleID),
|
||||
errors.Contains(err, apiutil.ErrMissingPolicyEntityType),
|
||||
errors.Contains(err, apiutil.ErrMissingRoleMembers),
|
||||
errors.Contains(err, apiutil.ErrMissingDescription),
|
||||
errors.Contains(err, apiutil.ErrMissingEntityID),
|
||||
errors.Contains(err, apiutil.ErrInvalidRouteFormat),
|
||||
errors.Contains(err, svcerr.ErrRetainOneMember),
|
||||
errors.Contains(err, apiutil.ErrMissingRoute):
|
||||
err = unwrap(err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
case errors.Contains(err, svcerr.ErrCreateEntity),
|
||||
errors.Contains(err, svcerr.ErrUpdateEntity),
|
||||
errors.Contains(err, svcerr.ErrRemoveEntity),
|
||||
errors.Contains(err, svcerr.ErrEnableClient),
|
||||
errors.Contains(err, svcerr.ErrEnableUser),
|
||||
errors.Contains(err, svcerr.ErrDisableUser):
|
||||
err = unwrap(err)
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
|
||||
case errors.Contains(err, svcerr.ErrNotFound):
|
||||
err = unwrap(err)
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
|
||||
case errors.Contains(err, errors.ErrStatusAlreadyAssigned),
|
||||
errors.Contains(err, svcerr.ErrInvitationAlreadyRejected),
|
||||
errors.Contains(err, svcerr.ErrInvitationAlreadyAccepted),
|
||||
errors.Contains(err, svcerr.ErrConflict):
|
||||
err = unwrap(err)
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
|
||||
case errors.Contains(err, apiutil.ErrUnsupportedContentType):
|
||||
err = unwrap(err)
|
||||
if err := json.NewEncoder(w).Encode(retErr); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case *errors.AuthZError:
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
if err := json.NewEncoder(w).Encode(retErr); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case *errors.MediaTypeError:
|
||||
w.WriteHeader(http.StatusUnsupportedMediaType)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(retErr); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case *errors.ServiceError:
|
||||
w.WriteHeader(http.StatusUnprocessableEntity)
|
||||
if err := json.NewEncoder(w).Encode(retErr); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case *errors.NotFoundError:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
if err := json.NewEncoder(w).Encode(retErr); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
case *errors.InternalError:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
default:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if wrapper != nil {
|
||||
err = errors.Wrap(wrapper, err)
|
||||
}
|
||||
|
||||
if errorVal, ok := err.(errors.Error); ok {
|
||||
if err := json.NewEncoder(w).Encode(errorVal); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func unwrap(err error) error {
|
||||
wrapper, err := errors.Unwrap(err)
|
||||
if wrapper != nil {
|
||||
return wrapper
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
+75
-92
@@ -218,121 +218,104 @@ func TestEncodeResponse(t *testing.T) {
|
||||
|
||||
func TestEncodeError(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
errs []error
|
||||
code int
|
||||
desc string
|
||||
err error
|
||||
code int
|
||||
hasBody bool
|
||||
checkError bool
|
||||
}{
|
||||
{
|
||||
desc: "BadRequest",
|
||||
errs: []error{
|
||||
apiutil.ErrMissingSecret,
|
||||
svcerr.ErrMalformedEntity,
|
||||
errors.ErrMalformedEntity,
|
||||
apiutil.ErrMissingID,
|
||||
apiutil.ErrEmptyList,
|
||||
apiutil.ErrMissingMemberType,
|
||||
apiutil.ErrMissingMemberKind,
|
||||
apiutil.ErrLimitSize,
|
||||
apiutil.ErrNameSize,
|
||||
svcerr.ErrViewEntity,
|
||||
},
|
||||
code: http.StatusBadRequest,
|
||||
desc: "RequestError - Missing Secret",
|
||||
err: apiutil.ErrMissingSecret,
|
||||
code: http.StatusBadRequest,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "BadRequest with validation error",
|
||||
errs: []error{
|
||||
errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingSecret),
|
||||
errors.Wrap(apiutil.ErrValidation, svcerr.ErrMalformedEntity),
|
||||
errors.Wrap(apiutil.ErrValidation, errors.ErrMalformedEntity),
|
||||
errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID),
|
||||
errors.Wrap(apiutil.ErrValidation, apiutil.ErrEmptyList),
|
||||
errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingMemberType),
|
||||
errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingMemberKind),
|
||||
errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize),
|
||||
errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize),
|
||||
},
|
||||
code: http.StatusBadRequest,
|
||||
desc: "RequestError - Missing ID",
|
||||
err: apiutil.ErrMissingID,
|
||||
code: http.StatusBadRequest,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "Unauthorized",
|
||||
errs: []error{
|
||||
svcerr.ErrAuthentication,
|
||||
svcerr.ErrAuthentication,
|
||||
apiutil.ErrBearerToken,
|
||||
},
|
||||
code: http.StatusUnauthorized,
|
||||
},
|
||||
|
||||
{
|
||||
desc: "NotFound",
|
||||
errs: []error{
|
||||
svcerr.ErrNotFound,
|
||||
},
|
||||
code: http.StatusNotFound,
|
||||
desc: "RequestError - Empty List",
|
||||
err: apiutil.ErrEmptyList,
|
||||
code: http.StatusBadRequest,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "Conflict",
|
||||
errs: []error{
|
||||
svcerr.ErrConflict,
|
||||
svcerr.ErrConflict,
|
||||
},
|
||||
code: http.StatusConflict,
|
||||
desc: "RequestError - Conflict",
|
||||
err: svcerr.ErrConflict,
|
||||
code: http.StatusBadRequest,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "Forbidden",
|
||||
errs: []error{
|
||||
svcerr.ErrAuthorization,
|
||||
svcerr.ErrAuthorization,
|
||||
svcerr.ErrDomainAuthorization,
|
||||
},
|
||||
code: http.StatusForbidden,
|
||||
desc: "NotFoundError - Not Found",
|
||||
err: svcerr.ErrNotFound,
|
||||
code: http.StatusNotFound,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "UnsupportedMediaType",
|
||||
errs: []error{
|
||||
apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
code: http.StatusUnsupportedMediaType,
|
||||
desc: "AuthNError - Authentication Failed",
|
||||
err: svcerr.ErrAuthentication,
|
||||
code: http.StatusUnauthorized,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "StatusUnprocessableEntity",
|
||||
errs: []error{
|
||||
svcerr.ErrCreateEntity,
|
||||
svcerr.ErrUpdateEntity,
|
||||
svcerr.ErrRemoveEntity,
|
||||
},
|
||||
code: http.StatusUnprocessableEntity,
|
||||
desc: "AuthZError - Authorization Failed",
|
||||
err: svcerr.ErrAuthorization,
|
||||
code: http.StatusForbidden,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "InternalServerError",
|
||||
errs: []error{
|
||||
errors.New("test"),
|
||||
},
|
||||
code: http.StatusInternalServerError,
|
||||
desc: "AuthZError - Domain Authorization Failed",
|
||||
err: svcerr.ErrDomainAuthorization,
|
||||
code: http.StatusForbidden,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "MediaTypeError - Unsupported Content Type",
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
code: http.StatusUnsupportedMediaType,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "ServiceError - Create Entity Failed",
|
||||
err: svcerr.ErrCreateEntity,
|
||||
code: http.StatusUnprocessableEntity,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "ServiceError - Update Entity Failed",
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
code: http.StatusUnprocessableEntity,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "ServiceError - Remove Entity Failed",
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
code: http.StatusUnprocessableEntity,
|
||||
hasBody: true,
|
||||
},
|
||||
{
|
||||
desc: "InternalError",
|
||||
err: errors.NewInternalError(),
|
||||
code: http.StatusInternalServerError,
|
||||
hasBody: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
responseWriter := newResponseWriter()
|
||||
for _, err := range c.errs {
|
||||
api.EncodeError(context.Background(), err, responseWriter)
|
||||
assert.Equal(t, c.code, responseWriter.StatusCode())
|
||||
|
||||
message := body{}
|
||||
jerr := json.Unmarshal(responseWriter.Body(), &message)
|
||||
assert.NoError(t, jerr)
|
||||
|
||||
var wrapper error
|
||||
switch errors.Contains(err, apiutil.ErrValidation) {
|
||||
case true:
|
||||
wrapper, err = errors.Unwrap(err)
|
||||
assert.Equal(t, err.Error(), message.Error)
|
||||
assert.Equal(t, wrapper.Error(), message.Message)
|
||||
case false:
|
||||
assert.Equal(t, err.Error(), message.Message)
|
||||
}
|
||||
api.EncodeError(context.Background(), c.err, responseWriter)
|
||||
assert.Equal(t, c.code, responseWriter.StatusCode())
|
||||
if !c.hasBody {
|
||||
return
|
||||
}
|
||||
message := body{}
|
||||
jerr := json.Unmarshal(responseWriter.Body(), &message)
|
||||
assert.NoError(t, jerr)
|
||||
assert.NotEmpty(t, message.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+91
-88
@@ -10,268 +10,271 @@ import "github.com/absmach/supermq/pkg/errors"
|
||||
// errors are logged twice.
|
||||
var (
|
||||
// ErrValidation indicates that an error was returned by the API.
|
||||
ErrValidation = errors.New("something went wrong with the request")
|
||||
ErrValidation = errors.NewRequestError("something went wrong with the request")
|
||||
|
||||
// ErrBearerToken indicates missing or invalid bearer user token.
|
||||
ErrBearerToken = errors.New("missing or invalid bearer user token")
|
||||
ErrBearerToken = errors.NewAuthNError("missing or invalid bearer user token")
|
||||
|
||||
// ErrBearerKey indicates missing or invalid bearer entity key.
|
||||
ErrBearerKey = errors.New("missing or invalid bearer entity key")
|
||||
ErrBearerKey = errors.NewAuthNError("missing or invalid bearer entity key")
|
||||
|
||||
// ErrMissingID indicates missing entity ID.
|
||||
ErrMissingID = errors.New("missing entity id")
|
||||
ErrMissingID = errors.NewRequestError("missing entity id")
|
||||
|
||||
// ErrMissingEntityID indicates missing entity ID.
|
||||
ErrMissingEntityID = errors.New("missing entity id")
|
||||
ErrMissingEntityID = errors.NewRequestError("missing entity id")
|
||||
|
||||
// ErrMissingClientID indicates missing client ID.
|
||||
ErrMissingClientID = errors.New("missing cient id")
|
||||
ErrMissingClientID = errors.NewRequestError("missing client id")
|
||||
|
||||
// ErrMissingChannelID indicates missing client ID.
|
||||
ErrMissingChannelID = errors.New("missing channel id")
|
||||
ErrMissingChannelID = errors.NewRequestError("missing channel id")
|
||||
|
||||
// ErrMissingConnectionType indicates missing connection tpye.
|
||||
ErrMissingConnectionType = errors.New("missing connection type")
|
||||
ErrMissingConnectionType = errors.NewRequestError("missing connection type")
|
||||
|
||||
// ErrMissingParentGroupID indicates missing parent group ID.
|
||||
ErrMissingParentGroupID = errors.New("missing parent group id")
|
||||
ErrMissingParentGroupID = errors.NewRequestError("missing parent group id")
|
||||
|
||||
// ErrMissingChildrenGroupIDs indicates missing children group IDs.
|
||||
ErrMissingChildrenGroupIDs = errors.New("missing children group ids")
|
||||
ErrMissingChildrenGroupIDs = errors.NewRequestError("missing children group ids")
|
||||
|
||||
// ErrSelfParentingNotAllowed indicates child id is same as parent id.
|
||||
ErrSelfParentingNotAllowed = errors.New("self parenting not allowed")
|
||||
ErrSelfParentingNotAllowed = errors.NewRequestError("self parenting not allowed")
|
||||
|
||||
// ErrInvalidChildGroupID indicates invalid child group ID.
|
||||
ErrInvalidChildGroupID = errors.New("invalid child group id")
|
||||
ErrInvalidChildGroupID = errors.NewRequestError("invalid child group id")
|
||||
|
||||
// ErrInvalidAuthKey indicates invalid auth key.
|
||||
ErrInvalidAuthKey = errors.New("invalid auth key")
|
||||
|
||||
// ErrInvalidIDFormat indicates an invalid ID format.
|
||||
ErrInvalidIDFormat = errors.New("invalid id format provided")
|
||||
ErrInvalidIDFormat = errors.NewRequestError("invalid id format provided")
|
||||
|
||||
// ErrNameSize indicates that name size exceeds the max.
|
||||
ErrNameSize = errors.New("invalid name size")
|
||||
ErrNameSize = errors.NewRequestError("invalid name size")
|
||||
|
||||
// ErrEmailSize indicates that email size exceeds the max.
|
||||
ErrEmailSize = errors.New("invalid email size")
|
||||
ErrEmailSize = errors.NewRequestError("invalid email size")
|
||||
|
||||
// ErrInvalidRole indicates that an invalid role.
|
||||
ErrInvalidRole = errors.New("invalid client role")
|
||||
ErrInvalidRole = errors.NewRequestError("invalid client role")
|
||||
|
||||
// ErrLimitSize indicates that an invalid limit.
|
||||
ErrLimitSize = errors.New("invalid limit size")
|
||||
ErrLimitSize = errors.NewRequestError("invalid limit size")
|
||||
|
||||
// ErrLevel indicates that an invalid level.
|
||||
ErrLevel = errors.New("invalid level")
|
||||
ErrLevel = errors.NewRequestError("invalid level")
|
||||
|
||||
// ErrOffsetSize indicates an invalid offset.
|
||||
ErrOffsetSize = errors.New("invalid offset size")
|
||||
ErrOffsetSize = errors.NewRequestError("invalid offset size")
|
||||
|
||||
// ErrInvalidOrder indicates an invalid list order.
|
||||
ErrInvalidOrder = errors.New("invalid list order provided")
|
||||
ErrInvalidOrder = errors.NewRequestError("invalid list order provided")
|
||||
|
||||
// ErrInvalidDirection indicates an invalid list direction.
|
||||
ErrInvalidDirection = errors.New("invalid list direction provided")
|
||||
ErrInvalidDirection = errors.NewRequestError("invalid list direction provided")
|
||||
|
||||
// ErrInvalidMemberKind indicates an invalid member kind.
|
||||
ErrInvalidMemberKind = errors.New("invalid member kind")
|
||||
ErrInvalidMemberKind = errors.NewRequestError("invalid member kind")
|
||||
|
||||
// ErrEmptyList indicates that entity data is empty.
|
||||
ErrEmptyList = errors.New("empty list provided")
|
||||
ErrEmptyList = errors.NewRequestError("empty list provided")
|
||||
|
||||
// ErrMissingRoleName indicates that role name is empty.
|
||||
ErrMissingRoleName = errors.New("empty role name")
|
||||
ErrMissingRoleName = errors.NewRequestError("empty role name")
|
||||
|
||||
// ErrMissingRoleID indicates that role id is empty.
|
||||
ErrMissingRoleID = errors.New("empty role id")
|
||||
ErrMissingRoleID = errors.NewRequestError("empty role id")
|
||||
|
||||
// ErrMissingRoleOperations indicates that role operations are empty.
|
||||
ErrMissingRoleOperations = errors.New("empty role operations")
|
||||
ErrMissingRoleOperations = errors.NewRequestError("empty role operations")
|
||||
|
||||
// ErrMissingRoleMembers indicates that role members are empty.
|
||||
ErrMissingRoleMembers = errors.New("empty role members")
|
||||
ErrMissingRoleMembers = errors.NewRequestError("empty role members")
|
||||
|
||||
// ErrMalformedPolicy indicates that policies are malformed.
|
||||
ErrMalformedPolicy = errors.New("malformed policy")
|
||||
ErrMalformedPolicy = errors.NewRequestError("malformed policy")
|
||||
|
||||
// ErrMissingPolicySub indicates that policies are subject.
|
||||
ErrMissingPolicySub = errors.New("malformed policy subject")
|
||||
ErrMissingPolicySub = errors.NewRequestError("malformed policy subject")
|
||||
|
||||
// ErrMissingPolicyObj indicates missing policies object.
|
||||
ErrMissingPolicyObj = errors.New("malformed policy object")
|
||||
ErrMissingPolicyObj = errors.NewRequestError("malformed policy object")
|
||||
|
||||
// ErrMalformedPolicyAct indicates missing policies action.
|
||||
ErrMalformedPolicyAct = errors.New("malformed policy action")
|
||||
ErrMalformedPolicyAct = errors.NewRequestError("malformed policy action")
|
||||
|
||||
// ErrMissingPolicyEntityType indicates missing policies entity type.
|
||||
ErrMissingPolicyEntityType = errors.New("missing policy entity type")
|
||||
ErrMissingPolicyEntityType = errors.NewRequestError("missing policy entity type")
|
||||
|
||||
// ErrMalformedPolicyPer indicates missing policies relation.
|
||||
ErrMalformedPolicyPer = errors.New("malformed policy permission")
|
||||
ErrMalformedPolicyPer = errors.NewRequestError("malformed policy permission")
|
||||
|
||||
// ErrMissingCertData indicates missing cert data (ttl).
|
||||
ErrMissingCertData = errors.New("missing certificate data")
|
||||
ErrMissingCertData = errors.NewRequestError("missing certificate data")
|
||||
|
||||
// ErrInvalidCertData indicates invalid cert data (ttl).
|
||||
ErrInvalidCertData = errors.New("invalid certificate data")
|
||||
ErrInvalidCertData = errors.NewRequestError("invalid certificate data")
|
||||
|
||||
// ErrInvalidTopic indicates an invalid subscription topic.
|
||||
ErrInvalidTopic = errors.New("invalid Subscription topic")
|
||||
ErrInvalidTopic = errors.NewRequestError("invalid Subscription topic")
|
||||
|
||||
// ErrInvalidContact indicates an invalid subscription contract.
|
||||
ErrInvalidContact = errors.New("invalid Subscription contact")
|
||||
ErrInvalidContact = errors.NewRequestError("invalid Subscription contact")
|
||||
|
||||
// ErrMissingEmail indicates missing email.
|
||||
ErrMissingEmail = errors.New("missing email")
|
||||
ErrMissingEmail = errors.NewRequestError("missing email")
|
||||
|
||||
// ErrInvalidEmail indicates missing email.
|
||||
ErrInvalidEmail = errors.New("invalid email")
|
||||
ErrInvalidEmail = errors.NewRequestError("invalid email")
|
||||
|
||||
// ErrMissingHost indicates missing host.
|
||||
ErrMissingHost = errors.New("missing host")
|
||||
ErrMissingHost = errors.NewRequestError("missing host")
|
||||
|
||||
// ErrMissingPass indicates missing password.
|
||||
ErrMissingPass = errors.New("missing password")
|
||||
ErrMissingPass = errors.NewRequestError("missing password")
|
||||
|
||||
// ErrMissingConfPass indicates missing conf password.
|
||||
ErrMissingConfPass = errors.New("missing conf password")
|
||||
ErrMissingConfPass = errors.NewRequestError("missing conf password")
|
||||
|
||||
// ErrInvalidResetPass indicates an invalid reset password.
|
||||
ErrInvalidResetPass = errors.New("invalid reset password")
|
||||
ErrInvalidResetPass = errors.NewRequestError("invalid reset password")
|
||||
|
||||
// ErrInvalidComparator indicates an invalid comparator.
|
||||
ErrInvalidComparator = errors.New("invalid comparator")
|
||||
ErrInvalidComparator = errors.NewRequestError("invalid comparator")
|
||||
|
||||
// ErrMissingMemberIDs indicates missing member ids.
|
||||
ErrMissingMemberIDs = errors.New("missing member ids")
|
||||
ErrMissingMemberIDs = errors.NewRequestError("missing member ids")
|
||||
|
||||
// ErrMissingMemberType indicates missing group member type.
|
||||
ErrMissingMemberType = errors.New("missing group member type")
|
||||
ErrMissingMemberType = errors.NewRequestError("missing group member type")
|
||||
|
||||
// ErrMissingMemberKind indicates missing group member kind.
|
||||
ErrMissingMemberKind = errors.New("missing group member kind")
|
||||
ErrMissingMemberKind = errors.NewRequestError("missing group member kind")
|
||||
|
||||
// ErrMissingRelation indicates missing relation.
|
||||
ErrMissingRelation = errors.New("missing relation")
|
||||
ErrMissingRelation = errors.NewRequestError("missing relation")
|
||||
|
||||
// ErrInvalidRelation indicates an invalid relation.
|
||||
ErrInvalidRelation = errors.New("invalid relation")
|
||||
ErrInvalidRelation = errors.NewRequestError("invalid relation")
|
||||
|
||||
// ErrInvalidAPIKey indicates an invalid API key type.
|
||||
ErrInvalidAPIKey = errors.New("invalid api key type")
|
||||
ErrInvalidAPIKey = errors.NewRequestError("invalid api key type")
|
||||
|
||||
// ErrInvitationState indicates an invalid invitation state.
|
||||
ErrInvitationState = errors.New("invalid invitation state")
|
||||
ErrInvitationState = errors.NewRequestError("invalid invitation state")
|
||||
|
||||
// ErrMissingIdentity indicates missing entity Identity.
|
||||
ErrMissingIdentity = errors.New("missing entity identity")
|
||||
ErrMissingIdentity = errors.NewRequestError("missing entity identity")
|
||||
|
||||
// ErrMissingSecret indicates missing secret.
|
||||
ErrMissingSecret = errors.New("missing secret")
|
||||
ErrMissingSecret = errors.NewRequestError("missing secret")
|
||||
|
||||
// ErrPasswordFormat indicates weak password.
|
||||
ErrPasswordFormat = errors.New("password does not meet the requirements")
|
||||
ErrPasswordFormat = errors.NewRequestError("password does not meet the requirements")
|
||||
|
||||
// ErrMissingName indicates missing identity name.
|
||||
ErrMissingName = errors.New("missing identity name")
|
||||
ErrMissingName = errors.NewRequestError("missing identity name")
|
||||
|
||||
// ErrMissingRoute indicates missing route.
|
||||
ErrMissingRoute = errors.New("missing route")
|
||||
ErrMissingRoute = errors.NewRequestError("missing route")
|
||||
|
||||
// ErrInvalidLevel indicates an invalid group level.
|
||||
ErrInvalidLevel = errors.New("invalid group level (should be between 0 and 5)")
|
||||
ErrInvalidLevel = errors.NewRequestError("invalid group level (should be between 0 and 5)")
|
||||
|
||||
// ErrNotFoundParam indicates that the parameter was not found in the query.
|
||||
ErrNotFoundParam = errors.New("parameter not found in the query")
|
||||
ErrNotFoundParam = errors.NewRequestError("parameter not found in the query")
|
||||
|
||||
// ErrInvalidQueryParams indicates invalid query parameters.
|
||||
ErrInvalidQueryParams = errors.New("invalid query parameters")
|
||||
ErrInvalidQueryParams = errors.NewRequestError("invalid query parameters")
|
||||
|
||||
// ErrInvalidVisibilityType indicates invalid visibility type.
|
||||
ErrInvalidVisibilityType = errors.New("invalid visibility type")
|
||||
ErrInvalidVisibilityType = errors.NewRequestError("invalid visibility type")
|
||||
|
||||
// ErrUnsupportedContentType indicates unacceptable or lack of Content-Type.
|
||||
ErrUnsupportedContentType = errors.New("unsupported content type")
|
||||
ErrUnsupportedContentType = errors.NewMediaTypeError("unsupported content type")
|
||||
|
||||
// ErrRollbackTx indicates failed to rollback transaction.
|
||||
ErrRollbackTx = errors.New("failed to rollback transaction")
|
||||
ErrRollbackTx = errors.NewRequestError("failed to rollback transaction")
|
||||
|
||||
// ErrInvalidAggregation indicates invalid aggregation value.
|
||||
ErrInvalidAggregation = errors.New("invalid aggregation value")
|
||||
ErrInvalidAggregation = errors.NewRequestError("invalid aggregation value")
|
||||
|
||||
// ErrInvalidInterval indicates invalid interval value.
|
||||
ErrInvalidInterval = errors.New("invalid interval value")
|
||||
ErrInvalidInterval = errors.NewRequestError("invalid interval value")
|
||||
|
||||
// ErrMissingFrom indicates missing from value.
|
||||
ErrMissingFrom = errors.New("missing from time value")
|
||||
ErrMissingFrom = errors.NewRequestError("missing from time value")
|
||||
|
||||
// ErrMissingTo indicates missing to value.
|
||||
ErrMissingTo = errors.New("missing to time value")
|
||||
ErrMissingTo = errors.NewRequestError("missing to time value")
|
||||
|
||||
// ErrEmptyMessage indicates empty message.
|
||||
ErrEmptyMessage = errors.New("empty message")
|
||||
ErrEmptyMessage = errors.NewRequestError("empty message")
|
||||
|
||||
// ErrMissingEntityType indicates missing entity type.
|
||||
ErrMissingEntityType = errors.New("missing entity type")
|
||||
ErrMissingEntityType = errors.NewRequestError("missing entity type")
|
||||
|
||||
// ErrInvalidEntityType indicates invalid entity type.
|
||||
ErrInvalidEntityType = errors.New("invalid entity type")
|
||||
ErrInvalidEntityType = errors.NewRequestError("invalid entity type")
|
||||
|
||||
// ErrInvalidTimeFormat indicates invalid time format i.e not unix time.
|
||||
ErrInvalidTimeFormat = errors.New("invalid time format use unix time")
|
||||
ErrInvalidTimeFormat = errors.NewRequestError("invalid time format use unix time")
|
||||
|
||||
// ErrEmptySearchQuery indicates search query should not be empty.
|
||||
ErrEmptySearchQuery = errors.New("search query must not be empty")
|
||||
ErrEmptySearchQuery = errors.NewRequestError("search query must not be empty")
|
||||
|
||||
// ErrLenSearchQuery indicates search query length.
|
||||
ErrLenSearchQuery = errors.New("search query must be at least 3 characters")
|
||||
ErrLenSearchQuery = errors.NewRequestError("search query must be at least 3 characters")
|
||||
|
||||
// ErrMissingDomainID indicates missing domainID.
|
||||
ErrMissingDomainID = errors.New("missing domainID")
|
||||
ErrMissingDomainID = errors.NewRequestError("missing domainID")
|
||||
|
||||
// ErrMissingUsername indicates missing user name.
|
||||
ErrMissingUsername = errors.New("missing username")
|
||||
ErrMissingUsername = errors.NewRequestError("missing username")
|
||||
|
||||
// ErrInvalidUsername indicates invalid user name.
|
||||
ErrInvalidUsername = errors.New("invalid username")
|
||||
ErrInvalidUsername = errors.NewRequestError("invalid username")
|
||||
|
||||
// ErrMissingFirstName indicates missing first name.
|
||||
ErrMissingFirstName = errors.New("missing first name")
|
||||
ErrMissingFirstName = errors.NewRequestError("missing first name")
|
||||
|
||||
// ErrMissingLastName indicates missing last name.
|
||||
ErrMissingLastName = errors.New("missing last name")
|
||||
ErrMissingLastName = errors.NewRequestError("missing last name")
|
||||
|
||||
// ErrInvalidProfilePictureURL indicates that the profile picture url is invalid.
|
||||
ErrInvalidProfilePictureURL = errors.New("invalid profile picture url")
|
||||
ErrInvalidProfilePictureURL = errors.NewRequestError("invalid profile picture url")
|
||||
|
||||
ErrMultipleEntitiesFilter = errors.New("multiple entities are provided in filter are not supported")
|
||||
ErrMultipleEntitiesFilter = errors.NewRequestError("multiple entities are provided in filter are not supported")
|
||||
|
||||
// ErrMissingDescription indicates missing description.
|
||||
ErrMissingDescription = errors.New("missing description")
|
||||
ErrMissingDescription = errors.NewRequestError("missing description")
|
||||
|
||||
// ErrUnsupportedTokenType indicates that this type of token is not supported.
|
||||
ErrUnsupportedTokenType = errors.New("unsupported content token type")
|
||||
ErrUnsupportedTokenType = errors.NewRequestError("unsupported content token type")
|
||||
|
||||
// ErrMissingUserID indicates missing user ID.
|
||||
ErrMissingUserID = errors.New("missing user id")
|
||||
ErrMissingUserID = errors.NewRequestError("missing user id")
|
||||
|
||||
// ErrMissingPATID indicates missing pat ID.
|
||||
ErrMissingPATID = errors.New("missing pat id")
|
||||
ErrMissingPATID = errors.NewRequestError("missing pat id")
|
||||
|
||||
// ErrInvalidNameFormat indicates invalid name format.
|
||||
ErrInvalidNameFormat = errors.New("invalid name format")
|
||||
ErrInvalidNameFormat = errors.NewRequestError("invalid name format")
|
||||
|
||||
// ErrInvalidRouteFormat indicates invalid route format.
|
||||
ErrInvalidRouteFormat = errors.New("invalid route format")
|
||||
ErrInvalidRouteFormat = errors.NewRequestError("invalid route format")
|
||||
|
||||
// ErrMissingUsernameEmail indicates missing user name / email.
|
||||
ErrMissingUsernameEmail = errors.New("missing username / email")
|
||||
ErrMissingUsernameEmail = errors.NewRequestError("missing username / email")
|
||||
|
||||
// ErrInvalidVerification indicates invalid email verification.
|
||||
ErrInvalidVerification = errors.New("invalid verification")
|
||||
ErrInvalidVerification = errors.NewRequestError("invalid verification")
|
||||
|
||||
// ErrEmailNotVerified indicates invalid email not verified.
|
||||
ErrEmailNotVerified = errors.New("email not verified")
|
||||
ErrEmailNotVerified = errors.NewRequestError("email not verified")
|
||||
|
||||
// ErrMalformedRequest indicates malformed request body.
|
||||
ErrMalformedRequestBody = errors.NewRequestError("request body is not a valid JSON, expecting a valid JSON")
|
||||
)
|
||||
|
||||
@@ -253,7 +253,7 @@ func TestRetrieve(t *testing.T) {
|
||||
desc: "retrieve a non-existing key",
|
||||
id: "non-existing",
|
||||
token: token.AccessToken,
|
||||
status: http.StatusBadRequest,
|
||||
status: http.StatusNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
|
||||
@@ -57,7 +57,7 @@ func decodeIssue(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
req := issueKeyReq{token: apiutil.ExtractBearerToken(r)}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
|
||||
@@ -140,7 +140,7 @@ func decodeCreatePATRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
}
|
||||
req := createPatReq{token: token}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -171,7 +171,7 @@ func decodeUpdatePATNameRequest(_ context.Context, r *http.Request) (any, error)
|
||||
id: chi.URLParam(r, "id"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -190,7 +190,7 @@ func decodeUpdatePATDescriptionRequest(_ context.Context, r *http.Request) (any,
|
||||
id: chi.URLParam(r, "id"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -262,7 +262,7 @@ func decodeResetPATSecretRequest(_ context.Context, r *http.Request) (any, error
|
||||
id: chi.URLParam(r, "id"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -305,7 +305,7 @@ func decodeAddScopeRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -348,7 +348,7 @@ func decodeRemoveScopeRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "id"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -26,10 +26,7 @@ const (
|
||||
secret = "test"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidIssuer = errors.New("invalid token issuer value")
|
||||
reposecret = []byte("test")
|
||||
)
|
||||
var reposecret = []byte("test")
|
||||
|
||||
func newToken(issuerName string, key auth.Key) string {
|
||||
builder := jwt.NewBuilder()
|
||||
@@ -194,7 +191,7 @@ func TestParse(t *testing.T) {
|
||||
desc: "parse token with invalid issuer",
|
||||
key: auth.Key{},
|
||||
token: inValidToken,
|
||||
err: errInvalidIssuer,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "parse token with invalid content",
|
||||
@@ -212,7 +209,7 @@ func TestParse(t *testing.T) {
|
||||
desc: "parse token with empty type",
|
||||
key: emptyTypeKey,
|
||||
token: emptyTypeToken,
|
||||
err: errors.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+5
-5
@@ -40,11 +40,11 @@ var (
|
||||
errMalformedPAT = errors.New("malformed personal access token")
|
||||
errFailedToParseUUID = errors.New("failed to parse string to UUID")
|
||||
errInvalidLenFor2UUIDs = errors.New("invalid input length for 2 UUID, excepted 32 byte")
|
||||
errRevokedPAT = errors.New("revoked pat")
|
||||
errCreatePAT = errors.New("failed to create PAT")
|
||||
errUpdatePAT = errors.New("failed to update PAT")
|
||||
errRetrievePAT = errors.New("failed to retrieve PAT")
|
||||
errDeletePAT = errors.New("failed to delete PAT")
|
||||
errRevokedPAT = errors.NewServiceError("revoked pat")
|
||||
errCreatePAT = errors.NewServiceError("failed to create PAT")
|
||||
errUpdatePAT = errors.NewServiceError("failed to update PAT")
|
||||
errRetrievePAT = errors.NewServiceError("failed to retrieve PAT")
|
||||
errDeletePAT = errors.NewServiceError("failed to delete PAT")
|
||||
errInvalidScope = errors.New("invalid scope")
|
||||
)
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ func decodeCreateChannelReq(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
req := createChannelReq{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req.Channel); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -51,7 +51,7 @@ func decodeCreateChannelsReq(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
req := createChannelsReq{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req.Channels); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -188,7 +188,7 @@ func decodeUpdateChannel(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -203,7 +203,7 @@ func decodeUpdateChannelTags(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -218,7 +218,7 @@ func decodeSetChannelParentGroupStatus(_ context.Context, r *http.Request) (any,
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -254,7 +254,7 @@ func decodeConnectChannelClientRequest(_ context.Context, r *http.Request) (any,
|
||||
channelID: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -268,7 +268,7 @@ func decodeDisconnectChannelClientsRequest(_ context.Context, r *http.Request) (
|
||||
channelID: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -281,7 +281,7 @@ func decodeConnectRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
req := connectRequest{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -294,7 +294,7 @@ func decodeDisconnectRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
req := disconnectRequest{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
|
||||
@@ -574,7 +574,7 @@ func TestListChannels(t *testing.T) {
|
||||
token: validToken,
|
||||
query: "offset=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list channels with limit",
|
||||
@@ -596,7 +596,7 @@ func TestListChannels(t *testing.T) {
|
||||
token: validToken,
|
||||
query: "limit=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list channels with limit greater than max",
|
||||
@@ -604,7 +604,7 @@ func TestListChannels(t *testing.T) {
|
||||
domainID: validID,
|
||||
query: fmt.Sprintf("limit=%d", api.MaxLimitSize+1),
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrLimitSize,
|
||||
},
|
||||
{
|
||||
desc: "list channels with name",
|
||||
@@ -656,7 +656,7 @@ func TestListChannels(t *testing.T) {
|
||||
token: validToken,
|
||||
query: "status=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: svcerr.ErrInvalidStatus,
|
||||
},
|
||||
{
|
||||
desc: "list channels with duplicate status",
|
||||
@@ -716,7 +716,7 @@ func TestListChannels(t *testing.T) {
|
||||
token: validToken,
|
||||
query: "metadata=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list channels with duplicate metadata",
|
||||
@@ -1114,7 +1114,7 @@ func TestUpdateChannelTagsEndpoint(t *testing.T) {
|
||||
contentType: contentType,
|
||||
data: fmt.Sprintf(`{"tags":["%s"}`, newTag),
|
||||
status: http.StatusBadRequest,
|
||||
err: errors.ErrMalformedEntity,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "update channel with empty id",
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
@@ -30,14 +29,13 @@ const (
|
||||
rolesTableNamePrefix = "channels"
|
||||
entityTableName = "channels"
|
||||
entityIDColumnName = "id"
|
||||
|
||||
pgDuplicateErrCode = "23505"
|
||||
)
|
||||
|
||||
var _ channels.Repository = (*channelRepository)(nil)
|
||||
|
||||
type channelRepository struct {
|
||||
db postgres.Database
|
||||
eh errors.Handler
|
||||
rolesPostgres.Repository
|
||||
}
|
||||
|
||||
@@ -45,8 +43,12 @@ type channelRepository struct {
|
||||
// repository.
|
||||
func NewRepository(db postgres.Database) channels.Repository {
|
||||
rolesRepo := rolesPostgres.NewRepository(db, policies.ChannelType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
|
||||
errHandlerOptions := []errors.HandlerOption{
|
||||
postgres.WithDuplicateErrors(NewDuplicateErrors()),
|
||||
}
|
||||
return &channelRepository{
|
||||
db: db,
|
||||
eh: postgres.NewErrorHandler(errHandlerOptions...),
|
||||
Repository: rolesRepo,
|
||||
}
|
||||
}
|
||||
@@ -67,7 +69,7 @@ func (cr *channelRepository) Save(ctx context.Context, chs ...channels.Channel)
|
||||
|
||||
row, err := cr.db.NamedQueryContext(ctx, q, dbchs)
|
||||
if err != nil {
|
||||
return []channels.Channel{}, handleSaveError(repoerr.ErrCreateEntity, err)
|
||||
return []channels.Channel{}, cr.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
defer row.Close()
|
||||
@@ -77,7 +79,7 @@ func (cr *channelRepository) Save(ctx context.Context, chs ...channels.Channel)
|
||||
for row.Next() {
|
||||
dbch := dbChannel{}
|
||||
if err := row.StructScan(&dbch); err != nil {
|
||||
return []channels.Channel{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
|
||||
return []channels.Channel{}, cr.eh.HandleError(repoerr.ErrFailedOpDB, err)
|
||||
}
|
||||
|
||||
ch, err := toChannel(dbch)
|
||||
@@ -89,16 +91,6 @@ func (cr *channelRepository) Save(ctx context.Context, chs ...channels.Channel)
|
||||
return reChs, nil
|
||||
}
|
||||
|
||||
func handleSaveError(wrapper, err error) error {
|
||||
if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgDuplicateErrCode {
|
||||
switch pqErr.ConstraintName {
|
||||
case "unique_domain_route_not_null":
|
||||
return errors.ErrRouteNotAvailable
|
||||
}
|
||||
}
|
||||
return postgres.HandleError(wrapper, err)
|
||||
}
|
||||
|
||||
func (cr *channelRepository) Update(ctx context.Context, channel channels.Channel) (channels.Channel, error) {
|
||||
var query []string
|
||||
var upq string
|
||||
@@ -144,14 +136,14 @@ func (cr *channelRepository) RetrieveByID(ctx context.Context, id string) (chann
|
||||
|
||||
row, err := cr.db.NamedQueryContext(ctx, q, dbch)
|
||||
if err != nil {
|
||||
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbch = dbChannel{}
|
||||
if row.Next() {
|
||||
if err := row.StructScan(&dbch); err != nil {
|
||||
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return toChannel(dbch)
|
||||
}
|
||||
@@ -170,14 +162,14 @@ func (cr *channelRepository) RetrieveByRoute(ctx context.Context, route, domainI
|
||||
|
||||
row, err := cr.db.NamedQueryContext(ctx, q, dbch)
|
||||
if err != nil {
|
||||
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbch = dbChannel{}
|
||||
if row.Next() {
|
||||
if err := row.StructScan(&dbch); err != nil {
|
||||
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return toChannel(dbch)
|
||||
}
|
||||
@@ -368,17 +360,17 @@ func (cr *channelRepository) RetrieveByIDWithRoles(ctx context.Context, id, memb
|
||||
}
|
||||
row, err := cr.db.NamedQueryContext(ctx, query, parameters)
|
||||
if err != nil {
|
||||
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbch := dbChannel{}
|
||||
if !row.Next() {
|
||||
return channels.Channel{}, errors.Wrap(repoerr.ErrNotFound, err)
|
||||
return channels.Channel{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
if err := row.StructScan(&dbch); err != nil {
|
||||
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return toChannel(dbch)
|
||||
@@ -450,14 +442,14 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.Page)
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := cr.db.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return channels.ChannelsPage{}, cr.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
dbch := dbChannel{}
|
||||
if err := rows.StructScan(&dbch); err != nil {
|
||||
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.ChannelsPage{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
ch, err := toChannel(dbch)
|
||||
@@ -476,7 +468,7 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.Page)
|
||||
|
||||
total, err := postgres.Total(ctx, cr.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.ChannelsPage{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := channels.ChannelsPage{
|
||||
@@ -566,14 +558,14 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
dbc := dbChannel{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := toChannel(dbc)
|
||||
@@ -617,7 +609,7 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := channels.ChannelsPage{
|
||||
@@ -912,7 +904,7 @@ func (cr *channelRepository) Remove(ctx context.Context, ids ...string) error {
|
||||
}
|
||||
result, err := cr.db.NamedExecContext(ctx, q, params)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
if rows, _ := result.RowsAffected(); rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
@@ -928,7 +920,7 @@ func (cr *channelRepository) SetParentGroup(ctx context.Context, ch channels.Cha
|
||||
}
|
||||
result, err := cr.db.NamedExecContext(ctx, q, dbCh)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
if rows, _ := result.RowsAffected(); rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
@@ -944,7 +936,7 @@ func (cr *channelRepository) RemoveParentGroup(ctx context.Context, ch channels.
|
||||
}
|
||||
result, err := cr.db.NamedExecContext(ctx, q, dbCh)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
if rows, _ := result.RowsAffected(); rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
@@ -958,7 +950,7 @@ func (cr *channelRepository) AddConnections(ctx context.Context, conns []channel
|
||||
VALUES (:channel_id, :domain_id, :client_id, :type );`
|
||||
|
||||
if _, err := cr.db.NamedExecContext(ctx, q, dbConns); err != nil {
|
||||
return postgres.HandleError(repoerr.ErrCreateEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -967,7 +959,7 @@ func (cr *channelRepository) AddConnections(ctx context.Context, conns []channel
|
||||
func (cr *channelRepository) RemoveConnections(ctx context.Context, conns []channels.Connection) (retErr error) {
|
||||
tx, err := cr.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
@@ -985,11 +977,11 @@ func (cr *channelRepository) RemoveConnections(ctx context.Context, conns []chan
|
||||
}
|
||||
dbConn := toDBConnection(conn)
|
||||
if _, err := tx.NamedExec(query, dbConn); err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, errors.Wrap(fmt.Errorf("failed to delete connection for channel_id: %s, domain_id: %s client_id %s", conn.ChannelID, conn.DomainID, conn.ClientID), err))
|
||||
return cr.eh.HandleError(repoerr.ErrRemoveEntity, errors.Wrap(fmt.Errorf("failed to delete connection for channel_id: %s, domain_id: %s client_id %s", conn.ChannelID, conn.DomainID, conn.ClientID), err))
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -999,7 +991,7 @@ func (cr *channelRepository) CheckConnection(ctx context.Context, conn channels.
|
||||
dbConn := toDBConnection(conn)
|
||||
rows, err := cr.db.NamedQueryContext(ctx, query, dbConn)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -1014,7 +1006,7 @@ func (cr *channelRepository) ClientAuthorize(ctx context.Context, conn channels.
|
||||
dbConn := toDBConnection(conn)
|
||||
rows, err := cr.db.NamedQueryContext(ctx, query, dbConn)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -1030,7 +1022,7 @@ func (cr *channelRepository) ChannelConnectionsCount(ctx context.Context, id str
|
||||
|
||||
total, err := postgres.Total(ctx, cr.db, query, dbConn)
|
||||
if err != nil {
|
||||
return 0, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return 0, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
@@ -1041,7 +1033,7 @@ func (cr *channelRepository) DoesChannelHaveConnections(ctx context.Context, id
|
||||
|
||||
rows, err := cr.db.NamedQueryContext(ctx, query, dbConn)
|
||||
if err != nil {
|
||||
return false, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return false, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -1053,7 +1045,7 @@ func (cr *channelRepository) RemoveClientConnections(ctx context.Context, client
|
||||
|
||||
dbConn := dbConnection{ClientID: clientID}
|
||||
if _, err := cr.db.NamedExecContext(ctx, query, dbConn); err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1063,7 +1055,7 @@ func (cr *channelRepository) RemoveChannelConnections(ctx context.Context, chann
|
||||
|
||||
dbConn := dbConnection{ChannelID: channelID}
|
||||
if _, err := cr.db.NamedExecContext(ctx, query, dbConn); err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1074,7 +1066,7 @@ func (cr *channelRepository) RetrieveParentGroupChannels(ctx context.Context, pa
|
||||
|
||||
rows, err := cr.db.NamedQueryContext(ctx, query, dbChannel{ParentGroup: toNullString(parentGroupID)})
|
||||
if err != nil {
|
||||
return []channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return []channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -1082,7 +1074,7 @@ func (cr *channelRepository) RetrieveParentGroupChannels(ctx context.Context, pa
|
||||
for rows.Next() {
|
||||
dbch := dbChannel{}
|
||||
if err := rows.StructScan(&dbch); err != nil {
|
||||
return []channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return []channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
ch, err := toChannel(dbch)
|
||||
@@ -1099,7 +1091,7 @@ func (cr *channelRepository) UnsetParentGroupFromChannels(ctx context.Context, p
|
||||
query := "UPDATE channels SET parent_group_id = NULL WHERE parent_group_id = :parent_group_id"
|
||||
|
||||
if _, err := cr.db.NamedExecContext(ctx, query, dbChannel{ParentGroup: toNullString(parentGroupID)}); err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, err)
|
||||
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1112,14 +1104,14 @@ func (cr *channelRepository) update(ctx context.Context, ch channels.Channel, qu
|
||||
|
||||
row, err := cr.db.NamedQueryContext(ctx, query, dbch)
|
||||
if err != nil {
|
||||
return channels.Channel{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbch = dbChannel{}
|
||||
if row.Next() {
|
||||
if err := row.StructScan(&dbch); err != nil {
|
||||
return channels.Channel{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
return toChannel(dbch)
|
||||
|
||||
@@ -102,6 +102,7 @@ var (
|
||||
"subgroup_set_parent",
|
||||
"subgroup_update",
|
||||
}
|
||||
errChannelExists = errors.New("channel id already exists")
|
||||
)
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
@@ -143,7 +144,7 @@ func TestSave(t *testing.T) {
|
||||
desc: "add duplicate channel",
|
||||
channel: validChannel,
|
||||
resp: []channels.Channel{},
|
||||
err: repoerr.ErrConflict,
|
||||
err: errChannelExists,
|
||||
},
|
||||
{
|
||||
desc: "add channel with invalid ID",
|
||||
@@ -156,7 +157,7 @@ func TestSave(t *testing.T) {
|
||||
Status: channels.EnabledStatus,
|
||||
},
|
||||
resp: []channels.Channel{},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add channel with invalid domain",
|
||||
@@ -169,7 +170,7 @@ func TestSave(t *testing.T) {
|
||||
Status: channels.EnabledStatus,
|
||||
},
|
||||
resp: []channels.Channel{},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add channel with invalid name",
|
||||
@@ -182,7 +183,7 @@ func TestSave(t *testing.T) {
|
||||
Status: channels.EnabledStatus,
|
||||
},
|
||||
resp: []channels.Channel{},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add channel with invalid metadata",
|
||||
@@ -197,7 +198,7 @@ func TestSave(t *testing.T) {
|
||||
Status: channels.EnabledStatus,
|
||||
},
|
||||
resp: []channels.Channel{},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add channel with duplicate name",
|
||||
@@ -1086,7 +1087,7 @@ func TestSetParentGroup(t *testing.T) {
|
||||
desc: "set parent group with invalid parent group ID",
|
||||
id: validChannel.ID,
|
||||
parentGroupID: invalidID,
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1208,7 +1209,7 @@ func TestAddConnection(t *testing.T) {
|
||||
DomainID: testsutil.GenerateUUID(t),
|
||||
Type: connections.Publish,
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add connection with invalid channel ID",
|
||||
@@ -1218,7 +1219,7 @@ func TestAddConnection(t *testing.T) {
|
||||
DomainID: testsutil.GenerateUUID(t),
|
||||
Type: connections.Publish,
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add connection with invalid domain ID",
|
||||
@@ -1228,7 +1229,7 @@ func TestAddConnection(t *testing.T) {
|
||||
DomainID: invalidID,
|
||||
Type: connections.Publish,
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import "github.com/absmach/supermq/pkg/errors"
|
||||
|
||||
var _ errors.Mapper = (*duplicateErrors)(nil)
|
||||
|
||||
type duplicateErrors struct{}
|
||||
|
||||
// GetError maps constraint names to known errors.
|
||||
func (d duplicateErrors) GetError(constraint string) (error, bool) {
|
||||
switch constraint {
|
||||
case "unique_domain_route_not_null":
|
||||
return errors.ErrRouteNotAvailable, true
|
||||
case "channels_pkey":
|
||||
return errors.NewRequestError("channel id already exists"), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func NewDuplicateErrors() errors.Mapper {
|
||||
return duplicateErrors{}
|
||||
}
|
||||
+1
-1
@@ -501,7 +501,7 @@ func (svc service) changeChannelStatus(ctx context.Context, userID string, chann
|
||||
return Channel{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
if dbchannel.Status == channel.Status {
|
||||
return Channel{}, errors.ErrStatusAlreadyAssigned
|
||||
return Channel{}, svcerr.ErrStatusAlreadyAssigned
|
||||
}
|
||||
|
||||
channel.UpdatedBy = userID
|
||||
|
||||
@@ -63,10 +63,9 @@ var (
|
||||
},
|
||||
},
|
||||
}
|
||||
parentGroupID = testsutil.GenerateUUID(&testing.T{})
|
||||
validID = testsutil.GenerateUUID(&testing.T{})
|
||||
validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID}
|
||||
errRollbackRoles = errors.New("failed to rollback roles")
|
||||
parentGroupID = testsutil.GenerateUUID(&testing.T{})
|
||||
validID = testsutil.GenerateUUID(&testing.T{})
|
||||
validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID}
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -203,7 +202,7 @@ func TestCreateChannel(t *testing.T) {
|
||||
},
|
||||
addRoleErr: svcerr.ErrCreateEntity,
|
||||
deletePoliciesErr: svcerr.ErrRemoveEntity,
|
||||
err: errRollbackRoles,
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -404,7 +403,7 @@ func TestEnableChannel(t *testing.T) {
|
||||
retrieveResp: channels.Channel{
|
||||
Status: channels.EnabledStatus,
|
||||
},
|
||||
err: errors.ErrStatusAlreadyAssigned,
|
||||
err: svcerr.ErrStatusAlreadyAssigned,
|
||||
},
|
||||
{
|
||||
desc: "enable channel with retrieve error",
|
||||
@@ -467,7 +466,7 @@ func TestDisableChannel(t *testing.T) {
|
||||
retrieveResp: channels.Channel{
|
||||
Status: channels.DisabledStatus,
|
||||
},
|
||||
err: errors.ErrStatusAlreadyAssigned,
|
||||
err: svcerr.ErrStatusAlreadyAssigned,
|
||||
},
|
||||
{
|
||||
desc: "disable channel with retrieve error",
|
||||
|
||||
@@ -170,7 +170,7 @@ func decodeUpdateClient(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, clientID),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -185,7 +185,7 @@ func decodeUpdateClientTags(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, clientID),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -200,7 +200,7 @@ func decodeUpdateClientCredentials(_ context.Context, r *http.Request) (any, err
|
||||
id: chi.URLParam(r, clientID),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -213,7 +213,7 @@ func decodeCreateClientReq(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
var c clients.Client
|
||||
if err := json.NewDecoder(r.Body).Decode(&c); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
req := createClientReq{
|
||||
client: c,
|
||||
@@ -229,7 +229,7 @@ func decodeCreateClientsReq(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
c := createClientsReq{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&c.Clients); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
@@ -252,7 +252,7 @@ func decodeSetClientParentGroupStatus(_ context.Context, r *http.Request) (any,
|
||||
id: chi.URLParam(r, clientID),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ func TestCreateClient(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: contentType,
|
||||
status: http.StatusConflict,
|
||||
status: http.StatusBadRequest,
|
||||
err: svcerr.ErrConflict,
|
||||
},
|
||||
{
|
||||
@@ -161,7 +161,7 @@ func TestCreateClient(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidIDFormat,
|
||||
},
|
||||
{
|
||||
desc: "register a client that can't be marshalled",
|
||||
@@ -179,7 +179,7 @@ func TestCreateClient(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: errors.ErrMalformedEntity,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "register client with invalid status",
|
||||
@@ -212,7 +212,7 @@ func TestCreateClient(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: "application/xml",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -316,7 +316,7 @@ func TestCreateClients(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrEmptyList,
|
||||
len: 0,
|
||||
},
|
||||
{
|
||||
@@ -337,7 +337,7 @@ func TestCreateClients(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidIDFormat,
|
||||
},
|
||||
{
|
||||
desc: "create clients with invalid contentype",
|
||||
@@ -351,7 +351,7 @@ func TestCreateClients(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: "application/xml",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "create a client that can't be marshalled",
|
||||
@@ -372,7 +372,7 @@ func TestCreateClients(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusBadRequest,
|
||||
err: errors.ErrMalformedEntity,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "create clients with service error",
|
||||
@@ -499,7 +499,7 @@ func TestListClients(t *testing.T) {
|
||||
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
|
||||
query: "offset=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list clients with limit",
|
||||
@@ -524,7 +524,7 @@ func TestListClients(t *testing.T) {
|
||||
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
|
||||
query: "limit=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list clients with limit greater than max",
|
||||
@@ -533,7 +533,7 @@ func TestListClients(t *testing.T) {
|
||||
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
|
||||
query: fmt.Sprintf("limit=%d", api.MaxLimitSize+1),
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrLimitSize,
|
||||
},
|
||||
{
|
||||
desc: "list clients with name",
|
||||
@@ -590,7 +590,7 @@ func TestListClients(t *testing.T) {
|
||||
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
|
||||
query: "status=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: svcerr.ErrInvalidStatus,
|
||||
},
|
||||
{
|
||||
desc: "list clients with duplicate status",
|
||||
@@ -656,7 +656,7 @@ func TestListClients(t *testing.T) {
|
||||
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
|
||||
query: "metadata=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list clients with duplicate metadata",
|
||||
@@ -913,8 +913,7 @@ func TestUpdateClient(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: "application/xml",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update client with malformed data",
|
||||
@@ -925,8 +924,7 @@ func TestUpdateClient(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "update client with empty id",
|
||||
@@ -937,8 +935,7 @@ func TestUpdateClient(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
|
||||
err: apiutil.ErrMissingID,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "update client with name that is too long",
|
||||
@@ -1020,8 +1017,7 @@ func TestUpdateClientsTags(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusOK,
|
||||
|
||||
err: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update client tags with empty token",
|
||||
@@ -1053,8 +1049,7 @@ func TestUpdateClientsTags(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusForbidden,
|
||||
|
||||
err: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "update client tags with invalid contentype",
|
||||
@@ -1065,7 +1060,7 @@ func TestUpdateClientsTags(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update clients tags with empty id",
|
||||
@@ -1076,8 +1071,7 @@ func TestUpdateClientsTags(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusBadRequest,
|
||||
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "update clients with malfomed data",
|
||||
@@ -1088,8 +1082,7 @@ func TestUpdateClientsTags(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusBadRequest,
|
||||
|
||||
err: errors.ErrMalformedEntity,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1203,7 +1196,7 @@ func TestUpdateClientSecret(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "update client secret with empty secret",
|
||||
@@ -1220,8 +1213,7 @@ func TestUpdateClientSecret(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusBadRequest,
|
||||
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMissingSecret,
|
||||
},
|
||||
{
|
||||
desc: "update client secret with invalid contentype",
|
||||
@@ -1238,8 +1230,7 @@ func TestUpdateClientSecret(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update client secret with malformed data",
|
||||
@@ -1256,8 +1247,7 @@ func TestUpdateClientSecret(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusBadRequest,
|
||||
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1316,8 +1306,7 @@ func TestEnableClient(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusOK,
|
||||
|
||||
err: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "enable client with invalid token",
|
||||
@@ -1337,8 +1326,7 @@ func TestEnableClient(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusBadRequest,
|
||||
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1401,8 +1389,7 @@ func TestDisableClient(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusOK,
|
||||
|
||||
err: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "disable client with invalid token",
|
||||
@@ -1422,8 +1409,7 @@ func TestDisableClient(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusBadRequest,
|
||||
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1481,8 +1467,7 @@ func TestDeleteClient(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
|
||||
status: http.StatusNoContent,
|
||||
|
||||
err: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "delete client with invalid token",
|
||||
|
||||
+26
-22
@@ -36,6 +36,7 @@ var _ clients.Repository = (*clientRepo)(nil)
|
||||
|
||||
type clientRepo struct {
|
||||
DB postgres.Database
|
||||
eh errors.Handler
|
||||
rolesPostgres.Repository
|
||||
}
|
||||
|
||||
@@ -43,9 +44,12 @@ type clientRepo struct {
|
||||
// implementation of Clients repository.
|
||||
func NewRepository(db postgres.Database) clients.Repository {
|
||||
repo := rolesPostgres.NewRepository(db, policies.ClientType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
|
||||
|
||||
errHandlerOptions := []errors.HandlerOption{
|
||||
postgres.WithDuplicateErrors(NewDuplicateErrors()),
|
||||
}
|
||||
return &clientRepo{
|
||||
DB: db,
|
||||
eh: postgres.NewErrorHandler(errHandlerOptions...),
|
||||
Repository: repo,
|
||||
}
|
||||
}
|
||||
@@ -66,7 +70,7 @@ func (repo *clientRepo) Save(ctx context.Context, cls ...clients.Client) ([]clie
|
||||
|
||||
row, err := repo.DB.NamedQueryContext(ctx, q, dbClients)
|
||||
if err != nil {
|
||||
return []clients.Client{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
|
||||
return []clients.Client{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
defer row.Close()
|
||||
@@ -75,7 +79,7 @@ func (repo *clientRepo) Save(ctx context.Context, cls ...clients.Client) ([]clie
|
||||
for row.Next() {
|
||||
dbcli := DBClient{}
|
||||
if err := row.StructScan(&dbcli); err != nil {
|
||||
return []clients.Client{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
|
||||
return []clients.Client{}, repo.eh.HandleError(repoerr.ErrFailedOpDB, err)
|
||||
}
|
||||
|
||||
client, err := ToClient(dbcli)
|
||||
@@ -108,14 +112,14 @@ func (repo *clientRepo) RetrieveBySecret(ctx context.Context, key, id string, pr
|
||||
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, dbc)
|
||||
if err != nil {
|
||||
return clients.Client{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
dbc = DBClient{}
|
||||
if rows.Next() {
|
||||
if err = rows.StructScan(&dbc); err != nil {
|
||||
return clients.Client{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
client, err := ToClient(dbc)
|
||||
@@ -365,17 +369,17 @@ func (repo *clientRepo) RetrieveByIDWithRoles(ctx context.Context, id, memberID
|
||||
}
|
||||
row, err := repo.DB.NamedQueryContext(ctx, query, parameters)
|
||||
if err != nil {
|
||||
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbc := DBClient{}
|
||||
if !row.Next() {
|
||||
return clients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
|
||||
return clients.Client{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
if err := row.StructScan(&dbc); err != nil {
|
||||
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return ToClient(dbc)
|
||||
@@ -391,14 +395,14 @@ func (repo *clientRepo) RetrieveByID(ctx context.Context, id string) (clients.Cl
|
||||
|
||||
row, err := repo.DB.NamedQueryContext(ctx, q, dbc)
|
||||
if err != nil {
|
||||
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbc = DBClient{}
|
||||
if row.Next() {
|
||||
if err := row.StructScan(&dbc); err != nil {
|
||||
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return ToClient(dbc)
|
||||
@@ -471,14 +475,14 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
dbc := DBClient{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToClient(dbc)
|
||||
@@ -497,7 +501,7 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien
|
||||
|
||||
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := clients.ClientsPage{
|
||||
@@ -588,14 +592,14 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
dbc := DBClient{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToClient(dbc)
|
||||
@@ -639,7 +643,7 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
|
||||
|
||||
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := clients.ClientsPage{
|
||||
@@ -945,7 +949,7 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli
|
||||
|
||||
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -953,7 +957,7 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli
|
||||
for rows.Next() {
|
||||
dbc := DBClient{}
|
||||
if err := rows.StructScan(&dbc); err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToClient(dbc)
|
||||
@@ -967,7 +971,7 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli
|
||||
cq := fmt.Sprintf(`SELECT COUNT(*) FROM clients c %s;`, tq)
|
||||
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := clients.ClientsPage{
|
||||
@@ -990,14 +994,14 @@ func (repo *clientRepo) update(ctx context.Context, client clients.Client, query
|
||||
|
||||
row, err := repo.DB.NamedQueryContext(ctx, query, dbc)
|
||||
if err != nil {
|
||||
return clients.Client{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return clients.Client{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbc = DBClient{}
|
||||
if row.Next() {
|
||||
if err := row.StructScan(&dbc); err != nil {
|
||||
return clients.Client{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return clients.Client{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
return ToClient(dbc)
|
||||
@@ -1014,7 +1018,7 @@ func (repo *clientRepo) Delete(ctx context.Context, clientIDs ...string) error {
|
||||
}
|
||||
result, err := repo.DB.NamedExecContext(ctx, q, params)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
if rows, _ := result.RowsAffected(); rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
|
||||
@@ -101,6 +101,7 @@ var (
|
||||
"subgroup_set_parent",
|
||||
"subgroup_update",
|
||||
}
|
||||
errClientSecretNotAvailable = errors.New("client key is not available")
|
||||
)
|
||||
|
||||
func TestClientsSave(t *testing.T) {
|
||||
@@ -189,7 +190,7 @@ func TestClientsSave(t *testing.T) {
|
||||
Status: clients.EnabledStatus,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrConflict,
|
||||
err: errClientSecretNotAvailable,
|
||||
},
|
||||
{
|
||||
desc: "add multiple clients with one client having duplicate secret",
|
||||
@@ -216,7 +217,7 @@ func TestClientsSave(t *testing.T) {
|
||||
Status: clients.EnabledStatus,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrConflict,
|
||||
err: errClientSecretNotAvailable,
|
||||
},
|
||||
{
|
||||
desc: "add new client without domain id",
|
||||
@@ -249,7 +250,7 @@ func TestClientsSave(t *testing.T) {
|
||||
Status: clients.EnabledStatus,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add multiple clients with one client having invalid client id",
|
||||
@@ -275,7 +276,7 @@ func TestClientsSave(t *testing.T) {
|
||||
Status: clients.EnabledStatus,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add client with invalid client name",
|
||||
@@ -292,7 +293,7 @@ func TestClientsSave(t *testing.T) {
|
||||
Status: clients.EnabledStatus,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add client with invalid client domain id",
|
||||
@@ -308,7 +309,7 @@ func TestClientsSave(t *testing.T) {
|
||||
Status: clients.EnabledStatus,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add client with invalid client identity",
|
||||
@@ -324,7 +325,7 @@ func TestClientsSave(t *testing.T) {
|
||||
Status: clients.EnabledStatus,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add client with a missing client identity",
|
||||
@@ -390,14 +391,16 @@ func TestClientsSave(t *testing.T) {
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
rClients, err := repo.Save(context.Background(), tc.clients...)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
for i := range rClients {
|
||||
tc.clients[i].Credentials.Secret = rClients[i].Credentials.Secret
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
rClients, err := repo.Save(context.Background(), tc.clients...)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
for i := range rClients {
|
||||
tc.clients[i].Credentials.Secret = rClients[i].Credentials.Secret
|
||||
}
|
||||
assert.Equal(t, tc.clients, rClients, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.clients, rClients))
|
||||
}
|
||||
assert.Equal(t, tc.clients, rClients, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.clients, rClients))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import "github.com/absmach/supermq/pkg/errors"
|
||||
|
||||
var _ errors.Mapper = (*duplicateErrors)(nil)
|
||||
|
||||
type duplicateErrors struct{}
|
||||
|
||||
// GetError maps constraint names to known errors.
|
||||
func (d duplicateErrors) GetError(constraint string) (error, bool) {
|
||||
switch constraint {
|
||||
case "clients_domain_id_secret_key":
|
||||
return errors.NewRequestError("client key is not available"), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func NewDuplicateErrors() errors.Mapper {
|
||||
return duplicateErrors{}
|
||||
}
|
||||
+11
-10
@@ -4,7 +4,6 @@ package clients
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
smq "github.com/absmach/supermq"
|
||||
@@ -20,9 +19,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errRollbackRepo = errors.New("failed to rollback repo")
|
||||
errSetParentGroup = errors.New("client already have parent")
|
||||
errSetSameParentGroup = errors.New("client already assigned to the parent group")
|
||||
errRollbackRepo = errors.New("failed to rollback repo")
|
||||
errSetParentGroup = errors.NewRequestError("client already have parent")
|
||||
errSetSameParentGroup = errors.NewRequestError("client already assigned to the parent group")
|
||||
errParentGroupDomainID = errors.NewRequestError("parent group has invalid domain id")
|
||||
errParentGroupDisabled = errors.NewRequestError("parent group is not enabled")
|
||||
)
|
||||
var _ Service = (*service)(nil)
|
||||
|
||||
@@ -59,14 +60,14 @@ func (svc service) CreateClients(ctx context.Context, session authn.Session, cls
|
||||
if c.ID == "" {
|
||||
clientID, err := svc.idProvider.ID()
|
||||
if err != nil {
|
||||
return []Client{}, []roles.RoleProvision{}, err
|
||||
return []Client{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrIssueProviderID, err)
|
||||
}
|
||||
c.ID = clientID
|
||||
}
|
||||
if c.Credentials.Secret == "" {
|
||||
key, err := svc.idProvider.ID()
|
||||
if err != nil {
|
||||
return []Client{}, []roles.RoleProvision{}, err
|
||||
return []Client{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrIssueProviderID, err)
|
||||
}
|
||||
c.Credentials.Secret = key
|
||||
}
|
||||
@@ -260,10 +261,10 @@ func (svc service) SetParentGroup(ctx context.Context, session authn.Session, pa
|
||||
return errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
if resp.GetEntity().GetDomainId() != session.DomainID {
|
||||
return errors.Wrap(svcerr.ErrUpdateEntity, fmt.Errorf("parent group id %s has invalid domain id", parentGroupID))
|
||||
return errors.Wrap(svcerr.ErrUpdateEntity, errParentGroupDomainID)
|
||||
}
|
||||
if resp.GetEntity().GetStatus() != uint32(EnabledStatus) {
|
||||
return errors.Wrap(svcerr.ErrUpdateEntity, fmt.Errorf("parent group id %s is not in enabled state", parentGroupID))
|
||||
return errors.Wrap(svcerr.ErrUpdateEntity, errParentGroupDisabled)
|
||||
}
|
||||
|
||||
var pols []policies.Policy
|
||||
@@ -326,7 +327,7 @@ func (svc service) RemoveParentGroup(ctx context.Context, session authn.Session,
|
||||
cli := Client{ID: id, UpdatedBy: session.UserID, UpdatedAt: time.Now().UTC()}
|
||||
|
||||
if err := svc.repo.RemoveParentGroup(ctx, cli); err != nil {
|
||||
return err
|
||||
return errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -388,7 +389,7 @@ func (svc service) changeClientStatus(ctx context.Context, session authn.Session
|
||||
return Client{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
if dbClient.Status == client.Status {
|
||||
return Client{}, errors.ErrStatusAlreadyAssigned
|
||||
return Client{}, svcerr.ErrStatusAlreadyAssigned
|
||||
}
|
||||
|
||||
client.UpdatedBy = session.UserID
|
||||
|
||||
@@ -771,8 +771,8 @@ func TestEnable(t *testing.T) {
|
||||
client: enabledClient1,
|
||||
changeStatusResponse: enabledClient1,
|
||||
retrieveByIDResponse: enabledClient1,
|
||||
changeStatusErr: errors.ErrStatusAlreadyAssigned,
|
||||
err: errors.ErrStatusAlreadyAssigned,
|
||||
changeStatusErr: svcerr.ErrStatusAlreadyAssigned,
|
||||
err: svcerr.ErrStatusAlreadyAssigned,
|
||||
},
|
||||
{
|
||||
desc: "enable non-existing client",
|
||||
@@ -844,8 +844,8 @@ func TestDisable(t *testing.T) {
|
||||
client: disabledClient1,
|
||||
changeStatusResponse: clients.Client{},
|
||||
retrieveByIDResponse: disabledClient1,
|
||||
changeStatusErr: errors.ErrStatusAlreadyAssigned,
|
||||
err: errors.ErrStatusAlreadyAssigned,
|
||||
changeStatusErr: svcerr.ErrStatusAlreadyAssigned,
|
||||
err: svcerr.ErrStatusAlreadyAssigned,
|
||||
},
|
||||
{
|
||||
desc: "disable non-existing client",
|
||||
|
||||
@@ -21,7 +21,6 @@ const (
|
||||
domainIDKey = "domain_id"
|
||||
invitedByKey = "invited_by"
|
||||
roleIDKey = "role_id"
|
||||
roleNameKey = "role_name"
|
||||
stateKey = "state"
|
||||
)
|
||||
|
||||
@@ -31,7 +30,7 @@ func decodeCreateDomainRequest(_ context.Context, r *http.Request) (any, error)
|
||||
}
|
||||
req := createDomainReq{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -59,7 +58,7 @@ func decodeUpdateDomainRequest(_ context.Context, r *http.Request) (any, error)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -191,7 +190,7 @@ func decodeSendInvitationReq(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
var req sendInvitationReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -257,7 +256,7 @@ func decodeAcceptInvitationReq(_ context.Context, r *http.Request) (any, error)
|
||||
|
||||
var req acceptInvitationReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -269,7 +268,7 @@ func decodeDeleteInvitationReq(_ context.Context, r *http.Request) (any, error)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
|
||||
@@ -211,7 +211,7 @@ func TestCreateDomain(t *testing.T) {
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "register domain with service error",
|
||||
@@ -563,14 +563,14 @@ func TestListDomains(t *testing.T) {
|
||||
status: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "list domains with invalid dir",
|
||||
desc: "list domains with invalid dir",
|
||||
token: validToken,
|
||||
query: "dir= ",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
},
|
||||
{
|
||||
desc: "list domains with duplicate dir",
|
||||
desc: "list domains with duplicate dir",
|
||||
token: validToken,
|
||||
query: "dir=asc&dir=asc",
|
||||
status: http.StatusBadRequest,
|
||||
@@ -585,7 +585,7 @@ func TestListDomains(t *testing.T) {
|
||||
Order: api.DefOrder,
|
||||
Dir: api.DefDir,
|
||||
},
|
||||
status: http.StatusBadRequest,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
listDomainsResp: domains.DomainsPage{},
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
@@ -652,10 +652,10 @@ func TestViewDomain(t *testing.T) {
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "view domain with invalid id",
|
||||
desc: "view domain with service error",
|
||||
token: validToken,
|
||||
domainID: invalid,
|
||||
status: http.StatusBadRequest,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
@@ -822,7 +822,7 @@ func TestUpdateDomain(t *testing.T) {
|
||||
},
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "update domain with invalid id",
|
||||
|
||||
+22
-30
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
@@ -31,12 +30,11 @@ const (
|
||||
rolesTableNamePrefix = "domains"
|
||||
entityTableName = "domains"
|
||||
entityIDColumnName = "id"
|
||||
|
||||
pgDuplicateErrCode = "23505"
|
||||
)
|
||||
|
||||
type domainRepo struct {
|
||||
db postgres.Database
|
||||
eh errors.Handler
|
||||
rolesPostgres.Repository
|
||||
}
|
||||
|
||||
@@ -44,8 +42,12 @@ type domainRepo struct {
|
||||
// implementation of Domain repository.
|
||||
func NewRepository(db postgres.Database) domains.Repository {
|
||||
rmsvcRepo := rolesPostgres.NewRepository(db, policies.DomainType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
|
||||
errHandlerOptions := []errors.HandlerOption{
|
||||
postgres.WithDuplicateErrors(NewDuplicateErrors()),
|
||||
}
|
||||
return &domainRepo{
|
||||
db: db,
|
||||
eh: postgres.NewErrorHandler(errHandlerOptions...),
|
||||
Repository: rmsvcRepo,
|
||||
}
|
||||
}
|
||||
@@ -62,7 +64,7 @@ func (repo domainRepo) SaveDomain(ctx context.Context, d domains.Domain) (dd dom
|
||||
|
||||
row, err := repo.db.NamedQueryContext(ctx, q, dbd)
|
||||
if err != nil {
|
||||
return domains.Domain{}, handleSaveError(repoerr.ErrCreateEntity, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
@@ -72,7 +74,7 @@ func (repo domainRepo) SaveDomain(ctx context.Context, d domains.Domain) (dd dom
|
||||
|
||||
dbd = dbDomain{}
|
||||
if err := row.StructScan(&dbd); err != nil {
|
||||
return domains.Domain{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrFailedOpDB, err)
|
||||
}
|
||||
|
||||
domain, err := toDomain(dbd)
|
||||
@@ -83,16 +85,6 @@ func (repo domainRepo) SaveDomain(ctx context.Context, d domains.Domain) (dd dom
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
func handleSaveError(wrapper, err error) error {
|
||||
if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgDuplicateErrCode {
|
||||
switch pqErr.ConstraintName {
|
||||
case "domains_route_key":
|
||||
return errors.ErrRouteNotAvailable
|
||||
}
|
||||
}
|
||||
return postgres.HandleError(wrapper, err)
|
||||
}
|
||||
|
||||
// RetrieveDomainByIDWithRoles retrieves Domain by its unique ID along with member roles.
|
||||
func (repo domainRepo) RetrieveDomainByIDWithRoles(ctx context.Context, id string, memberID string) (domains.Domain, error) {
|
||||
q := `
|
||||
@@ -176,14 +168,14 @@ func (repo domainRepo) RetrieveDomainByIDWithRoles(ctx context.Context, id strin
|
||||
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbdp)
|
||||
if err != nil {
|
||||
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
dbd := dbDomain{}
|
||||
if rows.Next() {
|
||||
if err = rows.StructScan(&dbd); err != nil {
|
||||
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
domain, err := toDomain(dbd)
|
||||
@@ -207,14 +199,14 @@ func (repo domainRepo) RetrieveDomainByID(ctx context.Context, id string) (domai
|
||||
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbdp)
|
||||
if err != nil {
|
||||
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
dbd := dbDomain{}
|
||||
if rows.Next() {
|
||||
if err = rows.StructScan(&dbd); err != nil {
|
||||
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
domain, err := toDomain(dbd)
|
||||
@@ -238,14 +230,14 @@ func (repo domainRepo) RetrieveDomainByRoute(ctx context.Context, route string)
|
||||
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbdom)
|
||||
if err != nil {
|
||||
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
dbd := dbDomain{}
|
||||
if rows.Next() {
|
||||
if err = rows.StructScan(&dbd); err != nil {
|
||||
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
domain, err := toDomain(dbd)
|
||||
@@ -284,13 +276,13 @@ func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.P
|
||||
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
doms, err := repo.processRows(rows)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
cq := "SELECT COUNT(*) FROM domains d"
|
||||
@@ -300,7 +292,7 @@ func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.P
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
return domains.DomainsPage{
|
||||
@@ -370,13 +362,13 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
doms, err = repo.processRows(rows)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,7 +385,7 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
|
||||
if err != nil {
|
||||
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
return domains.DomainsPage{
|
||||
@@ -451,7 +443,7 @@ func (repo domainRepo) UpdateDomain(ctx context.Context, id string, dr domains.D
|
||||
|
||||
row, err := repo.db.NamedQueryContext(ctx, q, dbd)
|
||||
if err != nil {
|
||||
return domains.Domain{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
@@ -461,7 +453,7 @@ func (repo domainRepo) UpdateDomain(ctx context.Context, id string, dr domains.D
|
||||
|
||||
dbd = dbDomain{}
|
||||
if err := row.StructScan(&dbd); err != nil {
|
||||
return domains.Domain{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
|
||||
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrFailedOpDB, err)
|
||||
}
|
||||
|
||||
domain, err := toDomain(dbd)
|
||||
@@ -478,7 +470,7 @@ func (repo domainRepo) DeleteDomain(ctx context.Context, id string) error {
|
||||
|
||||
res, err := repo.db.ExecContext(ctx, q, id)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
if rows, _ := res.RowsAffected(); rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
|
||||
@@ -21,8 +21,9 @@ import (
|
||||
const invalid = "invalid"
|
||||
|
||||
var (
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
userID = testsutil.GenerateUUID(&testing.T{})
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
userID = testsutil.GenerateUUID(&testing.T{})
|
||||
errDomainExists = errors.New("domain already exists")
|
||||
)
|
||||
|
||||
func TestSaveDomain(t *testing.T) {
|
||||
@@ -72,7 +73,7 @@ func TestSaveDomain(t *testing.T) {
|
||||
UpdatedBy: userID,
|
||||
Status: domains.EnabledStatus,
|
||||
},
|
||||
err: repoerr.ErrConflict,
|
||||
err: errDomainExists,
|
||||
},
|
||||
{
|
||||
desc: "add domain with empty ID",
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import "github.com/absmach/supermq/pkg/errors"
|
||||
|
||||
var _ errors.Mapper = (*duplicateErrors)(nil)
|
||||
|
||||
type duplicateErrors struct{}
|
||||
|
||||
// GetError maps constraint names to known errors.
|
||||
func (d duplicateErrors) GetError(constraint string) (error, bool) {
|
||||
switch constraint {
|
||||
case "domains_route_key":
|
||||
return errors.ErrRouteNotAvailable, true
|
||||
case "domains_pkey":
|
||||
return errors.NewRequestError("domain already exists"), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func NewDuplicateErrors() errors.Mapper {
|
||||
return duplicateErrors{}
|
||||
}
|
||||
+10
-20
@@ -26,26 +26,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
secret = "secret"
|
||||
email = "test@example.com"
|
||||
id = "testID"
|
||||
groupName = "smqx"
|
||||
description = "Description"
|
||||
memberRelation = "member"
|
||||
authoritiesObj = "authorities"
|
||||
loginDuration = 30 * time.Minute
|
||||
refreshDuration = 24 * time.Hour
|
||||
invalidDuration = 7 * 24 * time.Hour
|
||||
validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22"
|
||||
groupName = "smqx"
|
||||
validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrExpiry = errors.New("session is expired")
|
||||
errAddPolicies = errors.New("failed to add policies")
|
||||
errRollbackRepo = errors.New("failed to rollback repo")
|
||||
inValid = "invalid"
|
||||
valid = "valid"
|
||||
domain = domains.Domain{
|
||||
ErrExpiry = errors.New("session is expired")
|
||||
errAddPolicies = errors.New("failed to add policies")
|
||||
inValid = "invalid"
|
||||
valid = "valid"
|
||||
domain = domains.Domain{
|
||||
ID: validID,
|
||||
Name: groupName,
|
||||
Tags: []string{"tag1", "tag2"},
|
||||
@@ -173,7 +163,7 @@ func TestCreateDomain(t *testing.T) {
|
||||
session: validSession,
|
||||
addPoliciesErr: errAddPolicies,
|
||||
deleteDomainErr: svcerr.ErrRemoveEntity,
|
||||
err: errRollbackRepo,
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
},
|
||||
{
|
||||
desc: "create domain with failed to add roles",
|
||||
@@ -194,7 +184,7 @@ func TestCreateDomain(t *testing.T) {
|
||||
session: validSession,
|
||||
addRolesErr: errors.ErrMalformedEntity,
|
||||
deleteDomainErr: errors.ErrMalformedEntity,
|
||||
err: errRollbackRepo,
|
||||
err: errors.ErrMalformedEntity,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -206,7 +196,7 @@ func TestCreateDomain(t *testing.T) {
|
||||
policyCall := policy.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPoliciesErr)
|
||||
policyCall1 := policy.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
|
||||
_, _, err := svc.CreateDomain(context.Background(), tc.session, tc.d)
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err))
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
|
||||
@@ -22,7 +22,7 @@ func DecodeGroupCreate(_ context.Context, r *http.Request) (any, error) {
|
||||
}
|
||||
var g groups.Group
|
||||
if err := json.NewDecoder(r.Body).Decode(&g); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
req := createGroupReq{
|
||||
Group: g,
|
||||
@@ -63,7 +63,7 @@ func DecodeGroupUpdate(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "groupID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func decodeUpdateGroupTags(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "groupID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -126,7 +126,7 @@ func decodeAddParentGroupRequest(_ context.Context, r *http.Request) (any, error
|
||||
id: chi.URLParam(r, "groupID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -146,7 +146,7 @@ func decodeAddChildrenGroupsRequest(_ context.Context, r *http.Request) (any, er
|
||||
id: chi.URLParam(r, "groupID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -159,7 +159,7 @@ func decodeRemoveChildrenGroupsRequest(_ context.Context, r *http.Request) (any,
|
||||
id: chi.URLParam(r, "groupID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -347,7 +347,7 @@ func TestDecodeGroupCreate(t *testing.T) {
|
||||
"Content-Type": {api.ContentType},
|
||||
},
|
||||
resp: nil,
|
||||
err: errors.ErrMalformedEntity,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -400,7 +400,7 @@ func TestDecodeGroupUpdate(t *testing.T) {
|
||||
"Content-Type": {api.ContentType},
|
||||
},
|
||||
resp: nil,
|
||||
err: errors.ErrMalformedEntity,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -152,7 +152,7 @@ func TestCreateGroupEndpoint(t *testing.T) {
|
||||
},
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrNameSize,
|
||||
},
|
||||
{
|
||||
desc: "create group with name that is too long",
|
||||
@@ -572,7 +572,7 @@ func TestUpdateGroupTagsEndpoint(t *testing.T) {
|
||||
contentType: contentType,
|
||||
data: fmt.Sprintf(`{"tags":["%s"}`, newTag),
|
||||
status: http.StatusBadRequest,
|
||||
err: errors.ErrMalformedEntity,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "update group with empty id",
|
||||
@@ -886,7 +886,7 @@ func TestListGroups(t *testing.T) {
|
||||
token: validToken,
|
||||
query: "offset=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list groups with limit",
|
||||
@@ -908,7 +908,7 @@ func TestListGroups(t *testing.T) {
|
||||
token: validToken,
|
||||
query: "limit=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list groups with limit greater than max",
|
||||
@@ -916,7 +916,7 @@ func TestListGroups(t *testing.T) {
|
||||
domainID: validID,
|
||||
query: fmt.Sprintf("limit=%d", api.MaxLimitSize+1),
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrLimitSize,
|
||||
},
|
||||
{
|
||||
desc: "list groups with name",
|
||||
@@ -968,7 +968,7 @@ func TestListGroups(t *testing.T) {
|
||||
token: validToken,
|
||||
query: "status=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: svcerr.ErrInvalidStatus,
|
||||
},
|
||||
{
|
||||
desc: "list groups with duplicate status",
|
||||
@@ -1028,7 +1028,7 @@ func TestListGroups(t *testing.T) {
|
||||
token: validToken,
|
||||
query: "metadata=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list groups with duplicate metadata",
|
||||
@@ -1330,7 +1330,7 @@ func TestRetrieveGroupHierarchyEndpoint(t *testing.T) {
|
||||
domainID: validID,
|
||||
query: "level=invalid&dir=-1&tree=false",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "retrieve group hierarchy with invalid direction",
|
||||
@@ -1339,7 +1339,7 @@ func TestRetrieveGroupHierarchyEndpoint(t *testing.T) {
|
||||
domainID: validID,
|
||||
query: "level=1&dir=invalid&tree=false",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "retrieve group hierarchy with invalid tree",
|
||||
@@ -1348,7 +1348,7 @@ func TestRetrieveGroupHierarchyEndpoint(t *testing.T) {
|
||||
domainID: validID,
|
||||
query: "level=1&dir=-1&tree=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "retrieve group hierarchy with empty groupID",
|
||||
@@ -2116,7 +2116,7 @@ func TestListChildrenGroupsEndpoint(t *testing.T) {
|
||||
domainID: validID,
|
||||
query: "limit=invalid&offset=0",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list children groups with invalid offset",
|
||||
@@ -2125,7 +2125,7 @@ func TestListChildrenGroupsEndpoint(t *testing.T) {
|
||||
domainID: validID,
|
||||
query: "limit=1&offset=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list children groups with empty id",
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import "github.com/absmach/supermq/pkg/errors"
|
||||
|
||||
var _ errors.Mapper = (*duplicateErrors)(nil)
|
||||
|
||||
var errCyclicParentGroup = errors.NewRequestError("cyclic parent, group is parent of requested group")
|
||||
|
||||
type duplicateErrors struct{}
|
||||
|
||||
// GetError maps constraint names to known errors.
|
||||
func (d duplicateErrors) GetError(constraint string) (error, bool) {
|
||||
switch constraint {
|
||||
case "groups_pkey":
|
||||
return errors.NewRequestError("group id already exists"), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func NewDuplicateErrors() errors.Mapper {
|
||||
return duplicateErrors{}
|
||||
}
|
||||
+49
-45
@@ -41,6 +41,7 @@ var (
|
||||
|
||||
type groupRepository struct {
|
||||
db postgres.Database
|
||||
eh errors.Handler
|
||||
rolesPostgres.Repository
|
||||
}
|
||||
|
||||
@@ -48,9 +49,12 @@ type groupRepository struct {
|
||||
// repository.
|
||||
func New(db postgres.Database) groups.Repository {
|
||||
roleRepo := rolesPostgres.NewRepository(db, policies.GroupType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
|
||||
|
||||
errHandlerOptions := []errors.HandlerOption{
|
||||
postgres.WithDuplicateErrors(NewDuplicateErrors()),
|
||||
}
|
||||
return &groupRepository{
|
||||
db: db,
|
||||
eh: postgres.NewErrorHandler(errHandlerOptions...),
|
||||
Repository: roleRepo,
|
||||
}
|
||||
}
|
||||
@@ -62,19 +66,19 @@ func (repo groupRepository) Save(ctx context.Context, g groups.Group) (groups.Gr
|
||||
}
|
||||
dbg, err := toDBGroup(g)
|
||||
if err != nil {
|
||||
return groups.Group{}, err
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
|
||||
if err != nil {
|
||||
return groups.Group{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
defer row.Close()
|
||||
row.Next()
|
||||
dbg = dbGroup{}
|
||||
if err := row.StructScan(&dbg); err != nil {
|
||||
return groups.Group{}, err
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
return toGroup(dbg)
|
||||
@@ -107,7 +111,7 @@ func (repo groupRepository) Update(ctx context.Context, g groups.Group) (groups.
|
||||
|
||||
row, err := repo.db.NamedQueryContext(ctx, q, dbu)
|
||||
if err != nil {
|
||||
return groups.Group{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
defer row.Close()
|
||||
@@ -116,7 +120,7 @@ func (repo groupRepository) Update(ctx context.Context, g groups.Group) (groups.
|
||||
}
|
||||
dbu = dbGroup{}
|
||||
if err := row.StructScan(&dbu); err != nil {
|
||||
return groups.Group{}, errors.Wrap(err, repoerr.ErrUpdateEntity)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
return toGroup(dbu)
|
||||
}
|
||||
@@ -134,14 +138,14 @@ func (repo groupRepository) UpdateTags(ctx context.Context, group groups.Group)
|
||||
|
||||
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
|
||||
if err != nil {
|
||||
return groups.Group{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbg = dbGroup{}
|
||||
if row.Next() {
|
||||
if err := row.StructScan(&dbg); err != nil {
|
||||
return groups.Group{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
return toGroup(dbg)
|
||||
@@ -160,7 +164,7 @@ func (repo groupRepository) ChangeStatus(ctx context.Context, group groups.Group
|
||||
}
|
||||
row, err := repo.db.NamedQueryContext(ctx, qc, dbg)
|
||||
if err != nil {
|
||||
return groups.Group{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
if ok := row.Next(); !ok {
|
||||
@@ -168,7 +172,7 @@ func (repo groupRepository) ChangeStatus(ctx context.Context, group groups.Group
|
||||
}
|
||||
dbg = dbGroup{}
|
||||
if err := row.StructScan(&dbg); err != nil {
|
||||
return groups.Group{}, errors.Wrap(err, repoerr.ErrUpdateEntity)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
return toGroup(dbg)
|
||||
@@ -184,7 +188,7 @@ func (repo groupRepository) RetrieveByID(ctx context.Context, id string) (groups
|
||||
|
||||
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
@@ -193,7 +197,7 @@ func (repo groupRepository) RetrieveByID(ctx context.Context, id string) (groups
|
||||
return groups.Group{}, repoerr.ErrNotFound
|
||||
}
|
||||
if err := row.StructScan(&dbg); err != nil {
|
||||
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return toGroup(dbg)
|
||||
}
|
||||
@@ -341,17 +345,17 @@ func (repo groupRepository) RetrieveByIDWithRoles(ctx context.Context, id, membe
|
||||
}
|
||||
row, err := repo.db.NamedQueryContext(ctx, query, parameters)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbg := dbGroup{}
|
||||
if !row.Next() {
|
||||
return groups.Group{}, errors.Wrap(repoerr.ErrNotFound, err)
|
||||
return groups.Group{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
if err := row.StructScan(&dbg); err != nil {
|
||||
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return toGroup(dbg)
|
||||
@@ -394,7 +398,7 @@ func (repo groupRepository) RetrieveByIDAndUser(ctx context.Context, domainID, u
|
||||
|
||||
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
|
||||
if err != nil {
|
||||
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
@@ -403,7 +407,7 @@ func (repo groupRepository) RetrieveByIDAndUser(ctx context.Context, domainID, u
|
||||
return groups.Group{}, repoerr.ErrNotFound
|
||||
}
|
||||
if err := row.StructScan(&dbg); err != nil {
|
||||
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return toGroup(dbg)
|
||||
}
|
||||
@@ -446,13 +450,13 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items, err = repo.processRows(rows)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -465,7 +469,7 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
page := groups.Page{PageMeta: pm}
|
||||
@@ -490,13 +494,13 @@ func (repo groupRepository) RetrieveByIDs(ctx context.Context, pm groups.PageMet
|
||||
}
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items, err := repo.processRows(rows)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
cq := fmt.Sprintf(` SELECT COUNT(*) AS total_count
|
||||
@@ -508,7 +512,7 @@ func (repo groupRepository) RetrieveByIDs(ctx context.Context, pm groups.PageMet
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
page := groups.Page{PageMeta: pm}
|
||||
@@ -570,13 +574,13 @@ func (repo groupRepository) RetrieveHierarchy(ctx context.Context, domainID, use
|
||||
|
||||
rows, err := repo.db.NamedQueryContext(ctx, query, parameters)
|
||||
if err != nil {
|
||||
return groups.HierarchyPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.HierarchyPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items, err := repo.processRows(rows)
|
||||
if err != nil {
|
||||
return groups.HierarchyPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.HierarchyPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
return groups.HierarchyPage{HierarchyPageMeta: hm, Groups: items}, nil
|
||||
@@ -589,7 +593,7 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
|
||||
|
||||
tx, err := repo.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -602,13 +606,13 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
|
||||
pq := `SELECT id, path FROM groups WHERE id = $1 LIMIT 1;`
|
||||
rows, err := tx.Queryx(pq, parentGroupID)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
pGroups, err := repo.processRows(rows)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
if len(pGroups) == 0 {
|
||||
return repoerr.ErrUpdateEntity
|
||||
@@ -628,7 +632,7 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
|
||||
for _, sPath := range sPaths {
|
||||
for _, cgid := range groupIDs {
|
||||
if sPath == cgid {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, fmt.Errorf("cyclic parent, group %s is parent of requested group %s", cgid, parentGroupID))
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, errCyclicParentGroup)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -645,12 +649,12 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
|
||||
|
||||
crows, err := tx.NamedQuery(query, params)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer crows.Close()
|
||||
cgroups, err := repo.processRows(crows)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
childrenPaths := []string{}
|
||||
@@ -666,11 +670,11 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
|
||||
WHERE path <@ ANY($2::ltree[]);`
|
||||
|
||||
if _, err := tx.Exec(query, pGroup.Path, childrenPaths); err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -682,7 +686,7 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup
|
||||
|
||||
tx, err := repo.db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -694,13 +698,13 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup
|
||||
pq := `SELECT id, path FROM groups WHERE id = $1 LIMIT 1;`
|
||||
rows, err := tx.Queryx(pq, parentGroupID)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
pGroups, err := repo.processRows(rows)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
if len(pGroups) == 0 {
|
||||
return repoerr.ErrUpdateEntity
|
||||
@@ -725,12 +729,12 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup
|
||||
}
|
||||
crows, err := tx.NamedQuery(query, parameters)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer crows.Close()
|
||||
cgroups, err := repo.processRows(crows)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
childrenPaths := []string{}
|
||||
@@ -746,11 +750,11 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup
|
||||
WHERE path <@ ANY($2::ltree[]);`
|
||||
|
||||
if _, err := tx.Exec(query, pGroup.Path, childrenPaths); err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -764,7 +768,7 @@ func (repo groupRepository) UnassignAllChildrenGroups(ctx context.Context, id st
|
||||
|
||||
result, err := repo.db.NamedExecContext(ctx, query, dbGroup{ParentID: &id})
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
if rows, _ := result.RowsAffected(); rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
@@ -778,7 +782,7 @@ func (repo groupRepository) Delete(ctx context.Context, groupID string) error {
|
||||
|
||||
result, err := repo.db.ExecContext(ctx, q, groupID)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
if rows, _ := result.RowsAffected(); rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
@@ -922,13 +926,13 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
items, err = repo.processRows(rows)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -943,7 +947,7 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
|
||||
if err != nil {
|
||||
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
|
||||
page := groups.Page{PageMeta: pm}
|
||||
|
||||
@@ -61,6 +61,7 @@ var (
|
||||
"subgroup_remove_role_users",
|
||||
"subgroup_view_role_users",
|
||||
}
|
||||
errGroupExists = errors.NewRequestError("group id already exists")
|
||||
)
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
@@ -106,7 +107,7 @@ func TestSave(t *testing.T) {
|
||||
{
|
||||
desc: "add duplicate group",
|
||||
group: validGroup,
|
||||
err: repoerr.ErrConflict,
|
||||
err: errGroupExists,
|
||||
},
|
||||
{
|
||||
desc: "add group with parent",
|
||||
@@ -125,7 +126,7 @@ func TestSave(t *testing.T) {
|
||||
CreatedAt: validTimestamp,
|
||||
Status: groups.EnabledStatus,
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add group with invalid domain",
|
||||
@@ -138,7 +139,7 @@ func TestSave(t *testing.T) {
|
||||
CreatedAt: validTimestamp,
|
||||
Status: groups.EnabledStatus,
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add group with invalid parent",
|
||||
@@ -164,7 +165,7 @@ func TestSave(t *testing.T) {
|
||||
CreatedAt: validTimestamp,
|
||||
Status: groups.EnabledStatus,
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add group with invalid description",
|
||||
@@ -177,7 +178,7 @@ func TestSave(t *testing.T) {
|
||||
CreatedAt: validTimestamp,
|
||||
Status: groups.EnabledStatus,
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add group with invalid metadata",
|
||||
@@ -205,7 +206,7 @@ func TestSave(t *testing.T) {
|
||||
CreatedAt: validTimestamp,
|
||||
Status: groups.EnabledStatus,
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add group with duplicate name",
|
||||
|
||||
+13
-8
@@ -19,7 +19,12 @@ import (
|
||||
"github.com/absmach/supermq/pkg/roles"
|
||||
)
|
||||
|
||||
var ErrGroupIDs = errors.New("invalid group ids")
|
||||
var (
|
||||
ErrGroupIDs = errors.New("invalid group ids")
|
||||
errChangeGroupStatus = errors.NewServiceError("failed to change group status")
|
||||
errGroupHaveParent = errors.NewRequestError("group already have parent")
|
||||
errDifferentParent = errors.NewRequestError("groups have different parent")
|
||||
)
|
||||
|
||||
type service struct {
|
||||
repo Repository
|
||||
@@ -50,7 +55,7 @@ func NewService(repo Repository, policy policies.Service, idp supermq.IDProvider
|
||||
func (svc service) CreateGroup(ctx context.Context, session smqauthn.Session, g Group) (retGr Group, retRps []roles.RoleProvision, retErr error) {
|
||||
groupID, err := svc.idProvider.ID()
|
||||
if err != nil {
|
||||
return Group{}, []roles.RoleProvision{}, err
|
||||
return Group{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
if g.Status != EnabledStatus && g.Status != DisabledStatus {
|
||||
return Group{}, []roles.RoleProvision{}, svcerr.ErrInvalidStatus
|
||||
@@ -180,7 +185,7 @@ func (svc service) EnableGroup(ctx context.Context, session smqauthn.Session, id
|
||||
}
|
||||
group, err := svc.changeGroupStatus(ctx, session, group)
|
||||
if err != nil {
|
||||
return Group{}, err
|
||||
return Group{}, errors.Wrap(errChangeGroupStatus, err)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
@@ -193,7 +198,7 @@ func (svc service) DisableGroup(ctx context.Context, session smqauthn.Session, i
|
||||
}
|
||||
group, err := svc.changeGroupStatus(ctx, session, group)
|
||||
if err != nil {
|
||||
return Group{}, err
|
||||
return Group{}, errors.Wrap(errChangeGroupStatus, err)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
@@ -290,7 +295,7 @@ func (svc service) AddChildrenGroups(ctx context.Context, session smqauthn.Sessi
|
||||
|
||||
for _, childGroup := range childrenGroupsPage.Groups {
|
||||
if childGroup.Parent != "" {
|
||||
return errors.Wrap(svcerr.ErrConflict, fmt.Errorf("%s group already have parent", childGroup.ID))
|
||||
return errors.Wrap(svcerr.ErrConflict, errGroupHaveParent)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -336,7 +341,7 @@ func (svc service) RemoveChildrenGroups(ctx context.Context, session smqauthn.Se
|
||||
|
||||
for _, group := range childrenGroupsPage.Groups {
|
||||
if group.Parent != "" && group.Parent != parentGroupID {
|
||||
return errors.Wrap(svcerr.ErrConflict, fmt.Errorf("%s group doesn't have same parent", group.ID))
|
||||
return errors.Wrap(svcerr.ErrConflict, errDifferentParent)
|
||||
}
|
||||
pols = append(pols, policies.Policy{
|
||||
Domain: session.DomainID,
|
||||
@@ -440,7 +445,7 @@ func (svc service) DeleteGroup(ctx context.Context, session smqauthn.Session, id
|
||||
}
|
||||
|
||||
if err := svc.repo.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
return errors.Wrap(svcerr.ErrRemoveEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -452,7 +457,7 @@ func (svc service) changeGroupStatus(ctx context.Context, session smqauthn.Sessi
|
||||
return Group{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
if dbGroup.Status == group.Status {
|
||||
return Group{}, errors.ErrStatusAlreadyAssigned
|
||||
return Group{}, svcerr.ErrStatusAlreadyAssigned
|
||||
}
|
||||
|
||||
group.UpdatedBy = session.UserID
|
||||
|
||||
+10
-11
@@ -85,9 +85,8 @@ var (
|
||||
Status: groups.EnabledStatus,
|
||||
Children: children,
|
||||
}
|
||||
validID = testsutil.GenerateUUID(&testing.T{})
|
||||
errRollbackRoles = errors.New("failed to rollback roles")
|
||||
validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID}
|
||||
validID = testsutil.GenerateUUID(&testing.T{})
|
||||
validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID}
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -165,7 +164,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
group: validGroup,
|
||||
saveResp: groups.Group{},
|
||||
saveErr: errors.ErrMalformedEntity,
|
||||
err: errors.Wrap(svcerr.ErrCreateEntity, errors.ErrMalformedEntity),
|
||||
err: errors.ErrMalformedEntity,
|
||||
},
|
||||
{
|
||||
desc: " create group with failed to add policies",
|
||||
@@ -176,7 +175,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
Domain: validID,
|
||||
},
|
||||
addPoliciesErr: svcerr.ErrAuthorization,
|
||||
err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(svcerr.ErrCreateEntity, svcerr.ErrAuthorization)),
|
||||
err: svcerr.ErrAddPolicies,
|
||||
},
|
||||
{
|
||||
desc: " create group with failed to add policies and failed rollback",
|
||||
@@ -188,7 +187,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
},
|
||||
addPoliciesErr: svcerr.ErrAuthorization,
|
||||
deleteErr: svcerr.ErrRemoveEntity,
|
||||
err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(apiutil.ErrRollbackTx, svcerr.ErrRemoveEntity)),
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
},
|
||||
{
|
||||
desc: "create group with failed to add roles",
|
||||
@@ -199,7 +198,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
Domain: validID,
|
||||
},
|
||||
addRoleErr: svcerr.ErrCreateEntity,
|
||||
err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(svcerr.ErrCreateEntity, svcerr.ErrCreateEntity)),
|
||||
err: svcerr.ErrAddPolicies,
|
||||
},
|
||||
{
|
||||
desc: "create groups with failed to add roles and failed to delete policies",
|
||||
@@ -211,7 +210,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
},
|
||||
addRoleErr: svcerr.ErrCreateEntity,
|
||||
deletePoliciesErr: svcerr.ErrRemoveEntity,
|
||||
err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(svcerr.ErrCreateEntity, errors.Wrap(errRollbackRoles, svcerr.ErrRemoveEntity))),
|
||||
err: svcerr.ErrAddPolicies,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -223,7 +222,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr)
|
||||
repoCall2 := repo.On("Delete", context.Background(), mock.Anything).Return(tc.deleteErr)
|
||||
got, _, err := svc.CreateGroup(context.Background(), validSession, tc.group)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("expected error %v but got %v", tc.err, err))
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v but got %v", tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, got.ID)
|
||||
assert.NotEmpty(t, got.CreatedAt)
|
||||
@@ -418,7 +417,7 @@ func TestEnableGroup(t *testing.T) {
|
||||
retrieveResp: groups.Group{
|
||||
Status: groups.EnabledStatus,
|
||||
},
|
||||
err: errors.ErrStatusAlreadyAssigned,
|
||||
err: svcerr.ErrStatusAlreadyAssigned,
|
||||
},
|
||||
{
|
||||
desc: "enable group with retrieve error",
|
||||
@@ -472,7 +471,7 @@ func TestDisableGroup(t *testing.T) {
|
||||
retrieveResp: groups.Group{
|
||||
Status: groups.DisabledStatus,
|
||||
},
|
||||
err: errors.ErrStatusAlreadyAssigned,
|
||||
err: svcerr.ErrStatusAlreadyAssigned,
|
||||
},
|
||||
{
|
||||
desc: "disable group with retrieve error",
|
||||
|
||||
@@ -133,7 +133,9 @@ func TestAuthConnect(t *testing.T) {
|
||||
if ok {
|
||||
assert.Equal(t, tc.status, hpe.StatusCode())
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
if tc.err != nil {
|
||||
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error containing: %v, got: %v", tc.err, err))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -450,7 +452,9 @@ func TestPublish(t *testing.T) {
|
||||
if ok {
|
||||
assert.Equal(t, tc.status, hpe.StatusCode())
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err))
|
||||
if tc.err != nil {
|
||||
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error containing: %v, got: %v", tc.err, err))
|
||||
}
|
||||
authCall.Unset()
|
||||
repoCall.Unset()
|
||||
clientsCall.Unset()
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import "github.com/absmach/supermq/pkg/errors"
|
||||
|
||||
var _ errors.Mapper = (*duplicateErrors)(nil)
|
||||
|
||||
type duplicateErrors struct{}
|
||||
|
||||
// GetError maps constraint names to known errors.
|
||||
func (d duplicateErrors) GetError(constraint string) (error, bool) {
|
||||
switch constraint {
|
||||
case "journal_pkey":
|
||||
return errors.NewRequestError("journal entry already exists"), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func NewDuplicateErrors() errors.Mapper {
|
||||
return duplicateErrors{}
|
||||
}
|
||||
@@ -18,10 +18,17 @@ import (
|
||||
|
||||
type repository struct {
|
||||
db postgres.Database
|
||||
eh errors.Handler
|
||||
}
|
||||
|
||||
func NewRepository(db postgres.Database) journal.Repository {
|
||||
return &repository{db: db}
|
||||
errHandlerOptions := []errors.HandlerOption{
|
||||
postgres.WithDuplicateErrors(NewDuplicateErrors()),
|
||||
}
|
||||
return &repository{
|
||||
db: db,
|
||||
eh: postgres.NewErrorHandler(errHandlerOptions...),
|
||||
}
|
||||
}
|
||||
|
||||
func (repo *repository) Save(ctx context.Context, j journal.Journal) (err error) {
|
||||
@@ -41,11 +48,11 @@ func (repo *repository) Save(ctx context.Context, j journal.Journal) (err error)
|
||||
|
||||
dbJournal, err := toDBJournal(j)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrCreateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
if _, err = repo.db.NamedExecContext(ctx, q, dbJournal); err != nil {
|
||||
return postgres.HandleError(repoerr.ErrCreateEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -68,7 +75,7 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
|
||||
|
||||
rows, err := repo.db.NamedQueryContext(ctx, q, page)
|
||||
if err != nil {
|
||||
return journal.JournalsPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -76,7 +83,7 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
|
||||
for rows.Next() {
|
||||
var item dbJournal
|
||||
if err = rows.StructScan(&item); err != nil {
|
||||
return journal.JournalsPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
j, err := toJournal(item)
|
||||
if err != nil {
|
||||
@@ -89,7 +96,7 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
|
||||
|
||||
total, err := postgres.Total(ctx, repo.db, tq, page)
|
||||
if err != nil {
|
||||
return journal.JournalsPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
journalsPage := journal.JournalsPage{
|
||||
|
||||
@@ -87,7 +87,8 @@ var (
|
||||
"created_at": time.Now().Add(-time.Hour),
|
||||
"name": "group",
|
||||
}
|
||||
validTimeStamp = time.Now().UTC().Truncate(time.Millisecond)
|
||||
validTimeStamp = time.Now().UTC().Truncate(time.Millisecond)
|
||||
errJournalExists = errors.NewRequestError("journal entry already exists")
|
||||
)
|
||||
|
||||
func TestJournalSave(t *testing.T) {
|
||||
@@ -125,7 +126,7 @@ func TestJournalSave(t *testing.T) {
|
||||
Attributes: payload,
|
||||
Metadata: payload,
|
||||
},
|
||||
err: repoerr.ErrConflict,
|
||||
err: errJournalExists,
|
||||
},
|
||||
{
|
||||
desc: "with massive journal metadata and attributes",
|
||||
|
||||
+14
-10
@@ -417,7 +417,7 @@ func TestPublish(t *testing.T) {
|
||||
session: &sessionClient,
|
||||
topic: malformedSubtopics,
|
||||
payload: payload,
|
||||
err: errors.Wrap(mqtt.ErrFailedPublish, errors.Wrap(messaging.ErrMalformedTopic, errors.Wrap(messaging.ErrMalformedSubtopic, errors.New("invalid URL escape \"%\"")))),
|
||||
err: errors.New("invalid URL escape \"%\""),
|
||||
},
|
||||
{
|
||||
desc: "publish with subtopic containing wrong character",
|
||||
@@ -457,15 +457,19 @@ func TestPublish(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
ctx := context.TODO()
|
||||
if tc.session != nil {
|
||||
ctx = session.NewContext(ctx, tc.session)
|
||||
}
|
||||
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
err := handler.Publish(ctx, &tc.topic, &tc.payload)
|
||||
assert.Contains(t, logBuffer.String(), tc.logMsg)
|
||||
assert.Equal(t, tc.err, err)
|
||||
repoCall.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
if tc.session != nil {
|
||||
ctx = session.NewContext(ctx, tc.session)
|
||||
}
|
||||
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
err := handler.Publish(ctx, &tc.topic, &tc.payload)
|
||||
assert.Contains(t, logBuffer.String(), tc.logMsg)
|
||||
if tc.err != nil {
|
||||
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error containing: %v, got: %v", tc.err, err))
|
||||
}
|
||||
repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,244 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type NestError interface {
|
||||
Error
|
||||
Embed(e error) error
|
||||
}
|
||||
|
||||
var _ NestError = (*customError)(nil)
|
||||
|
||||
func (e *customError) Embed(err error) error {
|
||||
if err == nil {
|
||||
return e
|
||||
}
|
||||
|
||||
return &customError{
|
||||
msg: e.msg,
|
||||
err: fmt.Errorf("%w: %w", e.err, err),
|
||||
}
|
||||
}
|
||||
|
||||
type RequestError struct {
|
||||
customError
|
||||
}
|
||||
|
||||
var _ NestError = (*RequestError)(nil)
|
||||
|
||||
func NewRequestError(message string) NestError {
|
||||
return &RequestError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: errors.New(message),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewRequestErrorWithErr(message string, err error) NestError {
|
||||
return &RequestError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: fmt.Errorf("%w: %w", errors.New(message), err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *RequestError) Embed(err error) error {
|
||||
embedded := e.customError.Embed(err)
|
||||
return &RequestError{
|
||||
customError: *embedded.(*customError),
|
||||
}
|
||||
}
|
||||
|
||||
type AuthNError struct {
|
||||
customError
|
||||
}
|
||||
|
||||
var _ NestError = (*AuthNError)(nil)
|
||||
|
||||
func NewAuthNError(message string) NestError {
|
||||
return &AuthNError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: errors.New(message),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewAuthNErrorWithErr(message string, err error) NestError {
|
||||
return &AuthNError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: fmt.Errorf("%w: %w", errors.New(message), err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *AuthNError) Embed(err error) error {
|
||||
embedded := e.customError.Embed(err)
|
||||
return &AuthNError{
|
||||
customError: *embedded.(*customError),
|
||||
}
|
||||
}
|
||||
|
||||
var _ NestError = (*AuthZError)(nil)
|
||||
|
||||
type AuthZError struct {
|
||||
customError
|
||||
}
|
||||
|
||||
func (e *AuthZError) Embed(err error) error {
|
||||
embedded := e.customError.Embed(err)
|
||||
return &AuthZError{
|
||||
customError: *embedded.(*customError),
|
||||
}
|
||||
}
|
||||
|
||||
func NewAuthZError(message string) NestError {
|
||||
return &AuthZError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: errors.New(message),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewAuthZErrorWithErr(message string, err error) NestError {
|
||||
return &AuthZError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: fmt.Errorf("%w: %w", errors.New(message), err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type InternalError struct {
|
||||
customError
|
||||
}
|
||||
|
||||
var _ NestError = (*InternalError)(nil)
|
||||
|
||||
func NewInternalError() error {
|
||||
return &InternalError{
|
||||
customError: customError{
|
||||
msg: "internal server error",
|
||||
err: errors.New("internal server error"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewInternalErrorWithErr(err error) NestError {
|
||||
return &InternalError{
|
||||
customError: customError{
|
||||
msg: "internal server error",
|
||||
err: fmt.Errorf("%w: %w", errors.New("internal server error"), err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *InternalError) Embed(err error) error {
|
||||
embedded := e.customError.Embed(err)
|
||||
return &InternalError{
|
||||
customError: *embedded.(*customError),
|
||||
}
|
||||
}
|
||||
|
||||
type ServiceError struct {
|
||||
customError
|
||||
}
|
||||
|
||||
var _ NestError = (*ServiceError)(nil)
|
||||
|
||||
func NewServiceError(message string) NestError {
|
||||
return &ServiceError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: errors.New(message),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewServiceErrorWithErr(message string, err error) NestError {
|
||||
return &ServiceError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: fmt.Errorf("%w: %w", errors.New(message), err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ServiceError) Embed(err error) error {
|
||||
embedded := e.customError.Embed(err)
|
||||
return &ServiceError{
|
||||
customError: *embedded.(*customError),
|
||||
}
|
||||
}
|
||||
|
||||
type MediaTypeError struct {
|
||||
customError
|
||||
}
|
||||
|
||||
var _ NestError = (*MediaTypeError)(nil)
|
||||
|
||||
func NewMediaTypeError(message string) NestError {
|
||||
return &MediaTypeError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: errors.New(message),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewMediaTypeErrorWithErr(message string, err error) NestError {
|
||||
return &MediaTypeError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: fmt.Errorf("%w: %w", errors.New(message), err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *MediaTypeError) Embed(err error) error {
|
||||
embedded := e.customError.Embed(err)
|
||||
return &MediaTypeError{
|
||||
customError: *embedded.(*customError),
|
||||
}
|
||||
}
|
||||
|
||||
type NotFoundError struct {
|
||||
customError
|
||||
}
|
||||
|
||||
var _ NestError = (*NotFoundError)(nil)
|
||||
|
||||
func NewNotFoundError(message string) NestError {
|
||||
return &NotFoundError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: errors.New(message),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewNotFoundErrorWithErr(message string, err error) NestError {
|
||||
return &NotFoundError{
|
||||
customError: customError{
|
||||
msg: message,
|
||||
err: fmt.Errorf("%w: %w", errors.New(message), err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *NotFoundError) Embed(err error) error {
|
||||
embedded := e.customError.Embed(err)
|
||||
return &NotFoundError{
|
||||
customError: *embedded.(*customError),
|
||||
}
|
||||
}
|
||||
+15
-31
@@ -5,6 +5,8 @@ package errors
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Error specifies an API that must be fullfiled by error type.
|
||||
@@ -16,7 +18,7 @@ type Error interface {
|
||||
Msg() string
|
||||
|
||||
// Err returns wrapped error.
|
||||
Err() Error
|
||||
Err() error
|
||||
|
||||
// MarshalJSON returns a marshaled error.
|
||||
MarshalJSON() ([]byte, error)
|
||||
@@ -27,14 +29,14 @@ var _ Error = (*customError)(nil)
|
||||
// customError represents a SuperMQ error.
|
||||
type customError struct {
|
||||
msg string
|
||||
err Error
|
||||
err error
|
||||
}
|
||||
|
||||
// New returns an Error that formats as the given text.
|
||||
func New(text string) Error {
|
||||
return &customError{
|
||||
msg: text,
|
||||
err: nil,
|
||||
err: errors.New(text),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,32 +47,25 @@ func (ce *customError) Error() string {
|
||||
if ce.err == nil {
|
||||
return ce.msg
|
||||
}
|
||||
return ce.msg + " : " + ce.err.Error()
|
||||
return ce.err.Error()
|
||||
}
|
||||
|
||||
func (ce *customError) Msg() string {
|
||||
return ce.msg
|
||||
}
|
||||
|
||||
func (ce *customError) Err() Error {
|
||||
func (ce *customError) Err() error {
|
||||
return ce.err
|
||||
}
|
||||
|
||||
func (ce *customError) MarshalJSON() ([]byte, error) {
|
||||
var val string
|
||||
if e := ce.Err(); e != nil {
|
||||
val = e.Msg()
|
||||
}
|
||||
return json.Marshal(&struct {
|
||||
Err string `json:"error"`
|
||||
Msg string `json:"message"`
|
||||
}{
|
||||
Err: val,
|
||||
Msg: ce.Msg(),
|
||||
})
|
||||
}
|
||||
|
||||
// Contains inspects if e2 error is contained in any layer of e1 error.
|
||||
func Contains(e1, e2 error) bool {
|
||||
if e1 == nil || e2 == nil {
|
||||
return e2 == e1
|
||||
@@ -82,7 +77,8 @@ func Contains(e1, e2 error) bool {
|
||||
}
|
||||
return Contains(ce.Err(), e2)
|
||||
}
|
||||
return e1.Error() == e2.Error()
|
||||
|
||||
return errors.Is(e1, e2) || e1.Error() == e2.Error()
|
||||
}
|
||||
|
||||
// Wrap returns an Error that wrap err with wrapper.
|
||||
@@ -90,30 +86,18 @@ func Wrap(wrapper, err error) error {
|
||||
if wrapper == nil || err == nil {
|
||||
return wrapper
|
||||
}
|
||||
if w, ok := wrapper.(Error); ok {
|
||||
return &customError{
|
||||
msg: w.Msg(),
|
||||
err: cast(err),
|
||||
}
|
||||
if ne, ok := err.(NestError); ok {
|
||||
return ne.Embed(wrapper)
|
||||
}
|
||||
if ce, ok := wrapper.(NestError); ok {
|
||||
return ce.Embed(err)
|
||||
}
|
||||
return &customError{
|
||||
msg: wrapper.Error(),
|
||||
err: cast(err),
|
||||
err: fmt.Errorf("%w: %w", wrapper, err),
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap returns the wrapper and the error by separating the Wrapper from the error.
|
||||
func Unwrap(err error) (error, error) {
|
||||
if ce, ok := err.(Error); ok {
|
||||
if ce.Err() == nil {
|
||||
return nil, New(ce.Msg())
|
||||
}
|
||||
return New(ce.Msg()), ce.Err()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func cast(err error) Error {
|
||||
if err == nil {
|
||||
return nil
|
||||
|
||||
@@ -34,35 +34,35 @@ func TestError(t *testing.T) {
|
||||
desc: "level 0 wrapped error",
|
||||
err: err0,
|
||||
msg: "0",
|
||||
bytes: []byte(`{"error":"","message":"0"}`),
|
||||
bytes: []byte(`{"message":"0"}`),
|
||||
bytesErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "level 1 wrapped error",
|
||||
err: wrap(1),
|
||||
msg: message(1),
|
||||
bytes: []byte(`{"error":"0","message":"1"}`),
|
||||
bytes: []byte(`{"message":"0"}`),
|
||||
bytesErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "level 2 wrapped error",
|
||||
err: wrap(2),
|
||||
msg: message(2),
|
||||
bytes: []byte(`{"error":"1","message":"2"}`),
|
||||
bytes: []byte(`{"message":"0"}`),
|
||||
bytesErr: nil,
|
||||
},
|
||||
{
|
||||
desc: fmt.Sprintf("level %d wrapped error", level),
|
||||
err: wrap(level),
|
||||
msg: message(level),
|
||||
bytes: []byte(`{"error":"9","message":"` + strconv.Itoa(level) + `"}`),
|
||||
bytes: []byte(`{"message":"0"}`),
|
||||
bytesErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "nil error",
|
||||
err: errors.New(""),
|
||||
msg: "",
|
||||
bytes: []byte(`{"error":"","message":""}`),
|
||||
bytes: []byte(`{"message":""}`),
|
||||
bytesErr: nil,
|
||||
},
|
||||
}
|
||||
@@ -129,9 +129,9 @@ func TestContains(t *testing.T) {
|
||||
contains: true,
|
||||
},
|
||||
{
|
||||
desc: fmt.Sprintf("level %d wrapped error contains", level),
|
||||
desc: fmt.Sprintf("level %d wrapped error contains err0", level),
|
||||
container: wrap(level),
|
||||
contained: errors.New(strconv.Itoa(level / 2)),
|
||||
contained: err0,
|
||||
contains: true,
|
||||
},
|
||||
{
|
||||
@@ -276,66 +276,6 @@ func TestWrap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnwrap(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
err error
|
||||
wrapper error
|
||||
wrapped error
|
||||
}{
|
||||
{
|
||||
desc: "err 1 wraped err 2",
|
||||
err: errors.Wrap(err1, err2),
|
||||
wrapper: err1,
|
||||
wrapped: err2,
|
||||
},
|
||||
{
|
||||
desc: "err2 wraps err1 wraps err0",
|
||||
err: errors.Wrap(err2, errors.Wrap(err1, err0)),
|
||||
wrapper: err2,
|
||||
wrapped: errors.Wrap(err1, err0),
|
||||
},
|
||||
{
|
||||
desc: "nil wraps nil",
|
||||
err: errors.Wrap(nil, nil),
|
||||
wrapper: nil,
|
||||
wrapped: nil,
|
||||
},
|
||||
{
|
||||
desc: "err0 wraps nil",
|
||||
err: errors.Wrap(err0, nil),
|
||||
wrapper: nil,
|
||||
wrapped: err0,
|
||||
},
|
||||
{
|
||||
desc: "nil wraps err0",
|
||||
err: errors.Wrap(nil, err0),
|
||||
wrapper: nil,
|
||||
wrapped: nil,
|
||||
},
|
||||
{
|
||||
desc: "nil wraps native error",
|
||||
err: errors.Wrap(nil, nat),
|
||||
wrapper: nil,
|
||||
wrapped: nil,
|
||||
},
|
||||
{
|
||||
desc: "native error wraps nil",
|
||||
err: errors.Wrap(nat, nil),
|
||||
wrapper: nil,
|
||||
wrapped: nat,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
wrapper, wrapped := errors.Unwrap(c.err)
|
||||
assert.Equal(t, c.wrapper, wrapper)
|
||||
assert.Equal(t, c.wrapped, wrapped)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func wrap(level int) error {
|
||||
if level == 0 {
|
||||
return errors.New(strconv.Itoa(level))
|
||||
@@ -344,9 +284,10 @@ func wrap(level int) error {
|
||||
}
|
||||
|
||||
// message generates error message of wrap() generated wrapper error.
|
||||
// The error message format is now "innermost: ... : outermost" due to fmt.Errorf wrapping.
|
||||
func message(level int) string {
|
||||
if level == 0 {
|
||||
return "0"
|
||||
}
|
||||
return strconv.Itoa(level) + " : " + message(level-1)
|
||||
return message(level-1) + ": " + strconv.Itoa(level)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package errors
|
||||
|
||||
type Mapper interface {
|
||||
GetError(key string) (error, bool)
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
HandleError(wrapper, err error) error
|
||||
}
|
||||
|
||||
type HandlerOption func(*Handler)
|
||||
@@ -11,7 +11,7 @@ var (
|
||||
ErrMalformedEntity = errors.New("malformed entity specification")
|
||||
|
||||
// ErrNotFound indicates a non-existent entity request.
|
||||
ErrNotFound = errors.New("entity not found")
|
||||
ErrNotFound = errors.NewNotFoundError("entity not found")
|
||||
|
||||
// ErrConflict indicates that entity already exists.
|
||||
ErrConflict = errors.New("entity already exists")
|
||||
@@ -39,4 +39,13 @@ var (
|
||||
|
||||
// ErrMissingNames indicates missing first and last names.
|
||||
ErrMissingNames = errors.New("missing first or last name")
|
||||
|
||||
// ErrMarshalBDEntity indicates a failure to marshal a database entity.
|
||||
ErrMarshalBDEntity = errors.New("failed to marshal db entity")
|
||||
|
||||
// ErrUnmarshalBDEntity indicates a failure to unmarshal a database entity.
|
||||
ErrUnmarshalBDEntity = errors.New("failed to unmarshal db entity")
|
||||
|
||||
// ErrParseQueryParams indicates a failure to parse query parameters.
|
||||
ErrParseQueryParams = errors.NewRequestError("failed to parse query parameters")
|
||||
)
|
||||
|
||||
@@ -43,8 +43,11 @@ func TestNewSDKError(t *testing.T) {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
sdk := errors.NewSDKError(c.err)
|
||||
if c.err != nil {
|
||||
assert.Equal(t, sdk.StatusCode(), 0)
|
||||
assert.Equal(t, sdk.Error(), fmt.Sprintf("Status: %s: %s", http.StatusText(0), c.err.Error()))
|
||||
assert.NotNil(t, sdk)
|
||||
assert.Equal(t, 0, sdk.StatusCode())
|
||||
assert.Equal(t, fmt.Sprintf("Status: %s: %s", http.StatusText(0), c.err.Error()), sdk.Error())
|
||||
} else {
|
||||
assert.Nil(t, sdk)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -102,8 +105,11 @@ func TestNewSDKErrorWithStatus(t *testing.T) {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
sdk := errors.NewSDKErrorWithStatus(c.err, c.sc)
|
||||
if c.err != nil {
|
||||
assert.Equal(t, sdk.StatusCode(), c.sc)
|
||||
assert.Equal(t, sdk.Error(), fmt.Sprintf("Status: %s: %s", http.StatusText(c.sc), c.err.Error()))
|
||||
assert.NotNil(t, sdk)
|
||||
assert.Equal(t, c.sc, sdk.StatusCode())
|
||||
assert.Equal(t, fmt.Sprintf("Status: %s: %s", http.StatusText(c.sc), c.err.Error()), sdk.Error())
|
||||
} else {
|
||||
assert.Nil(t, sdk)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -196,10 +202,13 @@ func TestCheckError(t *testing.T) {
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
sdk := errors.CheckError(c.resp, c.codes...)
|
||||
assert.Equal(t, sdk, c.err)
|
||||
if c.err != nil {
|
||||
assert.Equal(t, sdk, c.err)
|
||||
assert.Equal(t, sdk.StatusCode(), c.resp.StatusCode)
|
||||
assert.NotNil(t, sdk)
|
||||
assert.Equal(t, c.err.StatusCode(), sdk.StatusCode())
|
||||
assert.Equal(t, c.err.Error(), sdk.Error())
|
||||
assert.Equal(t, c.resp.StatusCode, sdk.StatusCode())
|
||||
} else {
|
||||
assert.Nil(t, sdk)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+70
-55
@@ -8,101 +8,116 @@ import "github.com/absmach/supermq/pkg/errors"
|
||||
// Wrapper for Service errors.
|
||||
var (
|
||||
// ErrAuthentication indicates failure occurred while authenticating the entity.
|
||||
ErrAuthentication = errors.New("failed to perform authentication over the entity")
|
||||
|
||||
// ErrAuthorization indicates failure occurred while authorizing the entity.
|
||||
ErrAuthorization = errors.New("failed to perform authorization over the entity")
|
||||
|
||||
// ErrDomainAuthorization indicates failure occurred while authorizing the domain.
|
||||
ErrDomainAuthorization = errors.New("failed to perform authorization over the domain")
|
||||
ErrAuthentication = errors.NewAuthNError("failed to perform authentication over the entity")
|
||||
|
||||
// ErrLogin indicates wrong login credentials.
|
||||
ErrLogin = errors.New("invalid credentials")
|
||||
ErrLogin = errors.NewAuthNError("invalid credentials")
|
||||
|
||||
// ErrAuthorization indicates failure occurred while authorizing the entity.
|
||||
ErrAuthorization = errors.NewAuthZError("failed to perform authorization over the entity")
|
||||
|
||||
// ErrDomainAuthorization indicates failure occurred while authorizing the domain.
|
||||
ErrDomainAuthorization = errors.NewAuthZError("failed to perform authorization over the domain")
|
||||
|
||||
// ErrUnauthorizedPAT indicates failure occurred while authorizing PAT.
|
||||
ErrUnauthorizedPAT = errors.NewAuthZError("failed to authorize PAT")
|
||||
|
||||
// ErrSuperAdminAction indicates that the user is not a super admin.
|
||||
ErrSuperAdminAction = errors.NewAuthZError("not authorized to perform admin action")
|
||||
|
||||
// ErrCreateEntity indicates error in creating entity or entities.
|
||||
ErrCreateEntity = errors.NewServiceError("failed to create entity")
|
||||
|
||||
// ErrRemoveEntity indicates error in removing entity.
|
||||
ErrRemoveEntity = errors.NewServiceError("failed to remove entity")
|
||||
|
||||
// ErrViewEntity indicates error in viewing entity or entities.
|
||||
ErrViewEntity = errors.NewServiceError("view entity failed")
|
||||
|
||||
// ErrUpdateEntity indicates error in updating entity or entities.
|
||||
ErrUpdateEntity = errors.NewServiceError("update entity failed")
|
||||
|
||||
// ErrAddPolicies indicates error in adding policies.
|
||||
ErrAddPolicies = errors.NewServiceError("failed to add policies")
|
||||
|
||||
// ErrUserAlreadyVerified indicates user is already verified.
|
||||
ErrUserAlreadyVerified = errors.NewServiceError("user already verified")
|
||||
|
||||
// ErrInvalidUserVerification indicates user verification is invalid.
|
||||
ErrInvalidUserVerification = errors.NewServiceError("invalid verification")
|
||||
|
||||
// ErrIssueProviderID indicates failure to issue unique ID from ID provider.
|
||||
ErrIssueProviderID = errors.NewServiceError("failed to issue unique ID from id provider")
|
||||
|
||||
// ErrHashPassword indicates failure to hash password.
|
||||
ErrHashPassword = errors.NewServiceError("failed to hash password")
|
||||
|
||||
// ErrStatusAlreadyAssigned indicates that the client or group has already been assigned the status.
|
||||
ErrStatusAlreadyAssigned = errors.NewServiceError("status already assigned")
|
||||
|
||||
// ErrDeletePolicies indicates error in removing policies.
|
||||
ErrDeletePolicies = errors.NewServiceError("failed to remove policies")
|
||||
|
||||
// ErrMissingUsername indicates that the user's names are missing.
|
||||
ErrMissingUsername = errors.NewRequestError("missing usernames")
|
||||
|
||||
// ErrInvalidStatus indicates an invalid status.
|
||||
ErrInvalidStatus = errors.NewRequestError("invalid status")
|
||||
|
||||
// ErrInvalidRole indicates that an invalid role.
|
||||
ErrInvalidRole = errors.NewRequestError("invalid client role")
|
||||
|
||||
// ErrMalformedEntity indicates a malformed entity specification.
|
||||
ErrMalformedEntity = errors.New("malformed entity specification")
|
||||
|
||||
// ErrNotFound indicates a non-existent entity request.
|
||||
ErrNotFound = errors.New("entity not found")
|
||||
ErrNotFound = errors.NewNotFoundError("entity not found")
|
||||
|
||||
// ErrConflict indicates that entity already exists.
|
||||
ErrConflict = errors.New("entity already exists")
|
||||
|
||||
// ErrCreateEntity indicates error in creating entity or entities.
|
||||
ErrCreateEntity = errors.New("failed to create entity")
|
||||
|
||||
// ErrRemoveEntity indicates error in removing entity.
|
||||
ErrRemoveEntity = errors.New("failed to remove entity")
|
||||
|
||||
// ErrViewEntity indicates error in viewing entity or entities.
|
||||
ErrViewEntity = errors.New("view entity failed")
|
||||
|
||||
// ErrUpdateEntity indicates error in updating entity or entities.
|
||||
ErrUpdateEntity = errors.New("update entity failed")
|
||||
|
||||
// ErrInvalidStatus indicates an invalid status.
|
||||
ErrInvalidStatus = errors.New("invalid status")
|
||||
|
||||
// ErrInvalidRole indicates that an invalid role.
|
||||
ErrInvalidRole = errors.New("invalid client role")
|
||||
ErrConflict = errors.NewRequestError("entity already exists")
|
||||
|
||||
// ErrInvalidPolicy indicates that an invalid policy.
|
||||
ErrInvalidPolicy = errors.New("invalid policy")
|
||||
|
||||
// ErrEnableClient indicates error in enabling client.
|
||||
ErrEnableClient = errors.New("failed to enable client")
|
||||
ErrEnableClient = errors.NewServiceError("failed to enable client")
|
||||
|
||||
// ErrDisableClient indicates error in disabling client.
|
||||
ErrDisableClient = errors.New("failed to disable client")
|
||||
|
||||
// ErrAddPolicies indicates error in adding policies.
|
||||
ErrAddPolicies = errors.New("failed to add policies")
|
||||
|
||||
// ErrDeletePolicies indicates error in removing policies.
|
||||
ErrDeletePolicies = errors.New("failed to remove policies")
|
||||
ErrDisableClient = errors.NewServiceError("failed to disable client")
|
||||
|
||||
// ErrSearch indicates error in searching clients.
|
||||
ErrSearch = errors.New("failed to search clients")
|
||||
|
||||
// ErrInvitationAlreadyRejected indicates that the invitation is already rejected.
|
||||
ErrInvitationAlreadyRejected = errors.New("invitation already rejected")
|
||||
ErrInvitationAlreadyRejected = errors.NewRequestError("invitation already rejected")
|
||||
|
||||
// ErrInvitationAlreadyAccepted indicates that the invitation is already accepted.
|
||||
ErrInvitationAlreadyAccepted = errors.New("invitation already accepted")
|
||||
ErrInvitationAlreadyAccepted = errors.NewRequestError("invitation already accepted")
|
||||
|
||||
// ErrParentGroupAuthorization indicates failure occurred while authorizing the parent group.
|
||||
ErrParentGroupAuthorization = errors.New("failed to authorize parent group")
|
||||
|
||||
// ErrMissingUsername indicates that the user's names are missing.
|
||||
ErrMissingUsername = errors.New("missing usernames")
|
||||
|
||||
// ErrEnableUser indicates error in enabling user.
|
||||
ErrEnableUser = errors.New("failed to enable user")
|
||||
ErrEnableUser = errors.NewServiceError("failed to enable user")
|
||||
|
||||
// ErrDisableUser indicates error in disabling user.
|
||||
ErrDisableUser = errors.New("failed to disable user")
|
||||
ErrDisableUser = errors.NewServiceError("failed to disable user")
|
||||
|
||||
// ErrRollbackRepo indicates a failure to rollback repository.
|
||||
ErrRollbackRepo = errors.New("failed to rollback repo")
|
||||
|
||||
// ErrUnauthorizedPAT indicates failure occurred while authorizing PAT.
|
||||
ErrUnauthorizedPAT = errors.New("failed to authorize PAT")
|
||||
|
||||
// ErrRetainOneMember indicates that at least one owner must be retained in the entity.
|
||||
ErrRetainOneMember = errors.New("must retain at least one member")
|
||||
|
||||
// ErrSuperAdminAction indicates that the user is not a super admin.
|
||||
ErrSuperAdminAction = errors.New("not authorized to perform admin action")
|
||||
|
||||
// ErrUserAlreadyVerified indicates user is already verified.
|
||||
ErrUserAlreadyVerified = errors.New("user already verified")
|
||||
|
||||
// ErrInvalidUserVerification indicates user verification is invalid.
|
||||
ErrInvalidUserVerification = errors.New("invalid verification")
|
||||
|
||||
// ErrUserVerificationExpired indicates user verification is expired.
|
||||
ErrUserVerificationExpired = errors.New("verification expired, please generate new verification")
|
||||
|
||||
// ErrRegisterUser indicates error in register a user.
|
||||
ErrRegisterUser = errors.New("failed to register user")
|
||||
|
||||
// ErrExternalAuthProviderCouldNotUpdate indicates that users authenticated via external provider cannot update their account details directly.
|
||||
ErrExternalAuthProviderCouldNotUpdate = errors.New("account details can only be updated through your authentication provider's settings")
|
||||
|
||||
// ErrFailedToSaveEntityDB indicates failure to save entity to database.
|
||||
ErrFailedToSaveEntityDB = errors.New("failed to save entity to database")
|
||||
)
|
||||
|
||||
+3
-6
@@ -5,7 +5,7 @@ package errors
|
||||
|
||||
var (
|
||||
// ErrMalformedEntity indicates a malformed entity specification.
|
||||
ErrMalformedEntity = New("malformed entity specification")
|
||||
ErrMalformedEntity = NewRequestError("malformed entity specification")
|
||||
|
||||
// ErrUnsupportedContentType indicates invalid content type.
|
||||
ErrUnsupportedContentType = New("invalid content type")
|
||||
@@ -16,9 +16,6 @@ var (
|
||||
// ErrEmptyPath indicates empty file path.
|
||||
ErrEmptyPath = New("empty file path")
|
||||
|
||||
// ErrStatusAlreadyAssigned indicated that the client or group has already been assigned the status.
|
||||
ErrStatusAlreadyAssigned = New("status already assigned")
|
||||
|
||||
// ErrRollbackTx indicates failed to rollback transaction.
|
||||
ErrRollbackTx = New("failed to rollback transaction")
|
||||
|
||||
@@ -35,7 +32,7 @@ var (
|
||||
ErrMissingMember = New("member id is not found")
|
||||
|
||||
// ErrEmailAlreadyExists indicates that the email id already exists.
|
||||
ErrEmailAlreadyExists = New("email id already exists")
|
||||
ErrEmailAlreadyExists = New("email id already registered")
|
||||
|
||||
// ErrUsernameNotAvailable indicates that the username is not available.
|
||||
ErrUsernameNotAvailable = New("username not available")
|
||||
@@ -50,5 +47,5 @@ var (
|
||||
ErrTryAgain = New("Something went wrong, please try again")
|
||||
|
||||
// ErrRouteNotAvailable indicates that the username is not available.
|
||||
ErrRouteNotAvailable = New("route not available")
|
||||
ErrRouteNotAvailable = NewRequestError("route not available")
|
||||
)
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
var _ errors.Handler = (*errHandler)(nil)
|
||||
|
||||
type errHandler struct {
|
||||
duplicateErrors errors.Mapper
|
||||
}
|
||||
|
||||
func WithDuplicateErrors(mapper errors.Mapper) errors.HandlerOption {
|
||||
return func(eh *errors.Handler) {
|
||||
if h, ok := (*eh).(*errHandler); ok {
|
||||
h.duplicateErrors = mapper
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewErrorHandler(opts ...errors.HandlerOption) errors.Handler {
|
||||
var eh errors.Handler = &errHandler{}
|
||||
for _, opt := range opts {
|
||||
opt(&eh)
|
||||
}
|
||||
return eh
|
||||
}
|
||||
|
||||
// Handle handles the error.
|
||||
func (eh errHandler) HandleError(wrapper, err error) error {
|
||||
pqErr, ok := err.(*pgconn.PgError)
|
||||
if ok {
|
||||
switch pqErr.Code {
|
||||
case errDuplicate:
|
||||
if eh.duplicateErrors != nil {
|
||||
if knownErr, ok := eh.duplicateErrors.GetError(pqErr.ConstraintName); ok {
|
||||
return errors.Wrap(wrapper, knownErr)
|
||||
}
|
||||
}
|
||||
return errors.Wrap(wrapper, err)
|
||||
case errInvalid, errInvalidChar, errTruncation, errUntranslatable:
|
||||
return errors.Wrap(wrapper, err)
|
||||
case errFK:
|
||||
return errors.Wrap(wrapper, err)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Wrap(wrapper, err)
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func (d Decoder) DecodeCreateRole(_ context.Context, r *http.Request) (any, erro
|
||||
entityID: chi.URLParam(r, d.entityIDTemplate),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -135,7 +135,7 @@ func (d Decoder) DecodeRemoveEntityMembers(_ context.Context, r *http.Request) (
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -159,7 +159,7 @@ func (d Decoder) DecodeUpdateRole(_ context.Context, r *http.Request) (any, erro
|
||||
roleID: chi.URLParam(r, "roleID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -190,7 +190,7 @@ func (d Decoder) DecodeAddRoleActions(_ context.Context, r *http.Request) (any,
|
||||
roleID: chi.URLParam(r, "roleID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -214,7 +214,7 @@ func (d Decoder) DecodeDeleteRoleActions(_ context.Context, r *http.Request) (an
|
||||
roleID: chi.URLParam(r, "roleID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -238,7 +238,7 @@ func (d Decoder) DecodeAddRoleMembers(_ context.Context, r *http.Request) (any,
|
||||
roleID: chi.URLParam(r, "roleID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
@@ -272,7 +272,7 @@ func (d Decoder) DecodeDeleteRoleMembers(_ context.Context, r *http.Request) (an
|
||||
roleID: chi.URLParam(r, "roleID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
+32
-32
@@ -132,7 +132,7 @@ func TestCreateChannel(t *testing.T) {
|
||||
svcRes: []channels.Channel{},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "create channel with parent group",
|
||||
@@ -214,7 +214,7 @@ func TestCreateChannel(t *testing.T) {
|
||||
svcRes: []channels.Channel{iChannel},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -316,7 +316,7 @@ func TestCreateChannels(t *testing.T) {
|
||||
svcRes: []channels.Channel{},
|
||||
svcErr: nil,
|
||||
response: []sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "create channels with service response that can't be unmarshalled",
|
||||
@@ -334,7 +334,7 @@ func TestCreateChannels(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: []sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -493,7 +493,7 @@ func TestListChannels(t *testing.T) {
|
||||
svcRes: channels.ChannelsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ChannelsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list channels with level",
|
||||
@@ -571,7 +571,7 @@ func TestListChannels(t *testing.T) {
|
||||
svcRes: channels.ChannelsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ChannelsPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "list channels with service response that can't be unmarshalled",
|
||||
@@ -599,7 +599,7 @@ func TestListChannels(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.ChannelsPage{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -709,9 +709,9 @@ func TestViewChannel(t *testing.T) {
|
||||
withRoles: false,
|
||||
channelID: wrongID,
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
svcErr: svcerr.ErrNotFound,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
|
||||
},
|
||||
{
|
||||
desc: "view channel with empty channel id",
|
||||
@@ -738,7 +738,7 @@ func TestViewChannel(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -971,7 +971,7 @@ func TestUpdateChannel(t *testing.T) {
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update channel that can't be marshalled",
|
||||
@@ -988,7 +988,7 @@ func TestUpdateChannel(t *testing.T) {
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update channel with service response that can't be unmarshalled",
|
||||
@@ -1010,7 +1010,7 @@ func TestUpdateChannel(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
{
|
||||
desc: "update channel with empty channel id",
|
||||
@@ -1157,7 +1157,7 @@ func TestUpdateChannelTags(t *testing.T) {
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update channel tags with a response that can't be unmarshalled",
|
||||
@@ -1174,7 +1174,7 @@ func TestUpdateChannelTags(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1266,7 +1266,7 @@ func TestEnableChannel(t *testing.T) {
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "enable channel with service response that can't be unmarshalled",
|
||||
@@ -1281,7 +1281,7 @@ func TestEnableChannel(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1376,7 +1376,7 @@ func TestDisableChannel(t *testing.T) {
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "disable channel with service response that can't be unmarshalled",
|
||||
@@ -1391,7 +1391,7 @@ func TestDisableChannel(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1572,7 +1572,7 @@ func TestConnect(t *testing.T) {
|
||||
Types: []string{"Publish", "Subscribe"},
|
||||
},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "connect with empty client id",
|
||||
@@ -1584,7 +1584,7 @@ func TestConnect(t *testing.T) {
|
||||
Types: []string{"Publish", "Subscribe"},
|
||||
},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1679,7 +1679,7 @@ func TestDisconnect(t *testing.T) {
|
||||
Types: []string{"Publish", "Subscribe"},
|
||||
},
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "disconnect with empty channel id",
|
||||
@@ -1691,7 +1691,7 @@ func TestDisconnect(t *testing.T) {
|
||||
Types: []string{"Publish", "Subscribe"},
|
||||
},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "disconnect with empty client id",
|
||||
@@ -1703,7 +1703,7 @@ func TestDisconnect(t *testing.T) {
|
||||
Types: []string{"Publish", "Subscribe"},
|
||||
},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1802,7 +1802,7 @@ func TestConnectClients(t *testing.T) {
|
||||
clientID: clientID,
|
||||
connType: "Publish",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "connect with empty client id",
|
||||
@@ -1812,7 +1812,7 @@ func TestConnectClients(t *testing.T) {
|
||||
clientID: "",
|
||||
connType: "Publish",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1895,7 +1895,7 @@ func TestDisconnectClients(t *testing.T) {
|
||||
channelID: wrongID,
|
||||
clientID: clientID,
|
||||
connType: "Publish",
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "disconnect with empty channel id",
|
||||
@@ -1905,7 +1905,7 @@ func TestDisconnectClients(t *testing.T) {
|
||||
clientID: clientID,
|
||||
connType: "Publish",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "disconnect with empty client id",
|
||||
@@ -1915,7 +1915,7 @@ func TestDisconnectClients(t *testing.T) {
|
||||
clientID: "",
|
||||
connType: "Publish",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -2002,7 +2002,7 @@ func TestSetChannelParent(t *testing.T) {
|
||||
channelID: "",
|
||||
parentID: parentID,
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "set channel parent with empty parent id",
|
||||
@@ -2011,7 +2011,7 @@ func TestSetChannelParent(t *testing.T) {
|
||||
channelID: channel.ID,
|
||||
parentID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingParentGroupID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingParentGroupID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2097,7 +2097,7 @@ func TestRemoveChannelParent(t *testing.T) {
|
||||
channelID: "",
|
||||
parentID: parentID,
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+51
-51
@@ -136,7 +136,7 @@ func TestCreateClient(t *testing.T) {
|
||||
svcRes: []clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "create a client with invalid id",
|
||||
@@ -154,7 +154,7 @@ func TestCreateClient(t *testing.T) {
|
||||
svcRes: []clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "create a client with a request that can't be marshalled",
|
||||
@@ -170,7 +170,7 @@ func TestCreateClient(t *testing.T) {
|
||||
svcRes: []clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "create a client with a response that can't be unmarshalled",
|
||||
@@ -188,7 +188,7 @@ func TestCreateClient(t *testing.T) {
|
||||
}},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -281,7 +281,7 @@ func TestCreateClients(t *testing.T) {
|
||||
svcRes: []clients.Client{},
|
||||
svcErr: nil,
|
||||
response: []sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "create new clients with a response that can't be unmarshalled",
|
||||
@@ -299,7 +299,7 @@ func TestCreateClients(t *testing.T) {
|
||||
}},
|
||||
svcErr: nil,
|
||||
response: []sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -422,7 +422,7 @@ func TestListClients(t *testing.T) {
|
||||
svcRes: clients.ClientsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list all clients with name size greater than max",
|
||||
@@ -441,7 +441,7 @@ func TestListClients(t *testing.T) {
|
||||
svcRes: clients.ClientsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list all clients with status",
|
||||
@@ -532,7 +532,7 @@ func TestListClients(t *testing.T) {
|
||||
svcRes: clients.ClientsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "list all clients with response that can't be unmarshalled",
|
||||
@@ -566,7 +566,7 @@ func TestListClients(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.ClientsPage{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -670,9 +670,9 @@ func TestViewClient(t *testing.T) {
|
||||
withRoles: false,
|
||||
clientID: wrongID,
|
||||
svcRes: clients.Client{},
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
svcErr: svcerr.ErrNotFound,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
|
||||
},
|
||||
{
|
||||
desc: "view client with empty client id",
|
||||
@@ -701,7 +701,7 @@ func TestViewClient(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -852,7 +852,7 @@ func TestUpdateClient(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update client with a response that can't be unmarshalled",
|
||||
@@ -870,7 +870,7 @@ func TestUpdateClient(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -988,7 +988,7 @@ func TestUpdateClientTags(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update client tags with a request that can't be marshalled",
|
||||
@@ -1004,7 +1004,7 @@ func TestUpdateClientTags(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update client tags with a response that can't be unmarshalled",
|
||||
@@ -1022,7 +1022,7 @@ func TestUpdateClientTags(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1125,7 +1125,7 @@ func TestUpdateClientSecret(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update client with empty new secret",
|
||||
@@ -1136,7 +1136,7 @@ func TestUpdateClientSecret(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingSecret), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingSecret, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update client secret with a response that can't be unmarshalled",
|
||||
@@ -1154,7 +1154,7 @@ func TestUpdateClientSecret(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1240,7 +1240,7 @@ func TestEnableClient(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "enable client with a response that can't be unmarshalled",
|
||||
@@ -1257,7 +1257,7 @@ func TestEnableClient(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1333,7 +1333,7 @@ func TestDisableClient(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
svcErr: svcerr.ErrDisableClient,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrDisableClient, http.StatusInternalServerError),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrDisableClient, http.StatusUnprocessableEntity),
|
||||
},
|
||||
{
|
||||
desc: "disable client with empty client id",
|
||||
@@ -1343,7 +1343,7 @@ func TestDisableClient(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "disable client with a response that can't be unmarshalled",
|
||||
@@ -1360,7 +1360,7 @@ func TestDisableClient(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1527,7 +1527,7 @@ func TestSetClientParent(t *testing.T) {
|
||||
clientID: "",
|
||||
parentID: parentID,
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "set client parent with empty parent id",
|
||||
@@ -1536,7 +1536,7 @@ func TestSetClientParent(t *testing.T) {
|
||||
clientID: clientID,
|
||||
parentID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingParentGroupID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingParentGroupID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1622,7 +1622,7 @@ func TestRemoveClientParent(t *testing.T) {
|
||||
clientID: "",
|
||||
parentID: parentID,
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1741,7 +1741,7 @@ func TestCreateClientRole(t *testing.T) {
|
||||
svcRes: roles.RoleProvision{},
|
||||
svcErr: nil,
|
||||
response: sdk.Role{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "create client role with empty role name",
|
||||
@@ -1756,7 +1756,7 @@ func TestCreateClientRole(t *testing.T) {
|
||||
svcRes: roles.RoleProvision{},
|
||||
svcErr: nil,
|
||||
response: sdk.Role{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleName), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleName, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1887,7 +1887,7 @@ func TestListClientRoles(t *testing.T) {
|
||||
svcRes: roles.RolePage{},
|
||||
svcErr: nil,
|
||||
response: sdk.RolesPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1993,7 +1993,7 @@ func TestViewClientRole(t *testing.T) {
|
||||
svcRes: roles.Role{},
|
||||
svcErr: nil,
|
||||
response: sdk.Role{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "view client role with invalid role id",
|
||||
@@ -2122,7 +2122,7 @@ func TestUpdateClientRole(t *testing.T) {
|
||||
svcRes: roles.Role{},
|
||||
svcErr: nil,
|
||||
response: sdk.Role{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2209,7 +2209,7 @@ func TestDeleteClientRole(t *testing.T) {
|
||||
domainID: domainID,
|
||||
clientID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "delete client role with invalid role id",
|
||||
@@ -2319,7 +2319,7 @@ func TestAddClientRoleActions(t *testing.T) {
|
||||
roleID: roleID,
|
||||
actions: actions,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "add client role actions with invalid role id",
|
||||
@@ -2341,7 +2341,7 @@ func TestAddClientRoleActions(t *testing.T) {
|
||||
actions: []string{},
|
||||
svcErr: nil,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2433,7 +2433,7 @@ func TestListClientRoleActions(t *testing.T) {
|
||||
domainID: domainID,
|
||||
clientID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list client role actions with invalid role id",
|
||||
@@ -2451,7 +2451,7 @@ func TestListClientRoleActions(t *testing.T) {
|
||||
clientID: clientID,
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2545,7 +2545,7 @@ func TestRemoveClientRoleActions(t *testing.T) {
|
||||
clientID: "",
|
||||
roleID: roleID,
|
||||
actions: actions,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "remove client role actions with invalid role id",
|
||||
@@ -2565,7 +2565,7 @@ func TestRemoveClientRoleActions(t *testing.T) {
|
||||
roleID: roleID,
|
||||
actions: []string{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2651,7 +2651,7 @@ func TestRemoveAllClientRoleActions(t *testing.T) {
|
||||
domainID: domainID,
|
||||
clientID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "remove all client role actions with invalid role id",
|
||||
@@ -2669,7 +2669,7 @@ func TestRemoveAllClientRoleActions(t *testing.T) {
|
||||
clientID: clientID,
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2770,7 +2770,7 @@ func TestAddClientRoleMembers(t *testing.T) {
|
||||
roleID: roleID,
|
||||
members: members,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "add client role members with invalid role id",
|
||||
@@ -2792,7 +2792,7 @@ func TestAddClientRoleMembers(t *testing.T) {
|
||||
members: []string{},
|
||||
svcErr: nil,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2915,7 +2915,7 @@ func TestListClientRoleMembers(t *testing.T) {
|
||||
},
|
||||
clientID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list client role members with invalid role id",
|
||||
@@ -2941,7 +2941,7 @@ func TestListClientRoleMembers(t *testing.T) {
|
||||
},
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3035,7 +3035,7 @@ func TestRemoveClientRoleMembers(t *testing.T) {
|
||||
clientID: "",
|
||||
roleID: roleID,
|
||||
members: members,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "remove client role members with invalid role id",
|
||||
@@ -3055,7 +3055,7 @@ func TestRemoveClientRoleMembers(t *testing.T) {
|
||||
roleID: roleID,
|
||||
members: []string{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3141,7 +3141,7 @@ func TestRemoveAllClientRoleMembers(t *testing.T) {
|
||||
domainID: domainID,
|
||||
clientID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "remove all client role members with invalid role id",
|
||||
@@ -3159,7 +3159,7 @@ func TestRemoveAllClientRoleMembers(t *testing.T) {
|
||||
clientID: clientID,
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+16
-16
@@ -149,7 +149,7 @@ func TestCreateDomain(t *testing.T) {
|
||||
svcRes: domains.Domain{},
|
||||
svcErr: nil,
|
||||
response: sdk.Domain{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "create domain with response that cannot be unmarshalled",
|
||||
@@ -165,7 +165,7 @@ func TestCreateDomain(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Domain{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -294,7 +294,7 @@ func TestUpdateDomain(t *testing.T) {
|
||||
svcRes: domains.Domain{},
|
||||
svcErr: nil,
|
||||
response: sdk.Domain{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update domain with response that cannot be unmarshalled",
|
||||
@@ -313,7 +313,7 @@ func TestUpdateDomain(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Domain{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -439,7 +439,7 @@ func TestViewDomain(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Domain{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -568,7 +568,7 @@ func TestListDomians(t *testing.T) {
|
||||
svcRes: domains.DomainsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.DomainsPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "list domains with request that cannot be marshalled",
|
||||
@@ -592,7 +592,7 @@ func TestListDomians(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.DomainsPage{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -938,7 +938,7 @@ func TestCreateDomainRole(t *testing.T) {
|
||||
svcRes: roles.RoleProvision{},
|
||||
svcErr: nil,
|
||||
response: sdk.Role{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleName), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleName, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1487,7 +1487,7 @@ func TestAddDomainRoleActions(t *testing.T) {
|
||||
actions: []string{},
|
||||
svcErr: nil,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1589,7 +1589,7 @@ func TestListDomainRoleActions(t *testing.T) {
|
||||
domainID: domainID,
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1695,7 +1695,7 @@ func TestRemoveDomainRoleActions(t *testing.T) {
|
||||
roleID: roleID,
|
||||
actions: []string{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1791,7 +1791,7 @@ func TestRemoveAllDomainRoleActions(t *testing.T) {
|
||||
domainID: domainID,
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1906,7 +1906,7 @@ func TestAddDomainRoleMembers(t *testing.T) {
|
||||
members: []string{},
|
||||
svcErr: nil,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2047,7 +2047,7 @@ func TestListDomainRoleMembers(t *testing.T) {
|
||||
},
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2153,7 +2153,7 @@ func TestRemoveDomainRoleMembers(t *testing.T) {
|
||||
roleID: roleID,
|
||||
members: []string{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2249,7 +2249,7 @@ func TestRemoveAllDomainRoleMembers(t *testing.T) {
|
||||
domainID: domainID,
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+52
-52
@@ -211,7 +211,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
svcRes: groups.Group{},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "create group with name that is too long",
|
||||
@@ -226,7 +226,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
svcRes: groups.Group{},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "create group with request that cannot be marshalled",
|
||||
@@ -243,7 +243,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
svcRes: groups.Group{},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "create group with service response that cannot be unmarshalled",
|
||||
@@ -262,7 +262,7 @@ func TestCreateGroup(t *testing.T) {
|
||||
svcRes: uGroup,
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -425,7 +425,7 @@ func TestListGroups(t *testing.T) {
|
||||
svcRes: groups.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.GroupsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list groups with given name",
|
||||
@@ -480,7 +480,7 @@ func TestListGroups(t *testing.T) {
|
||||
svcRes: groups.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.GroupsPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "list groups with service response that cannot be unmarshalled",
|
||||
@@ -513,7 +513,7 @@ func TestListGroups(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.GroupsPage{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -615,9 +615,9 @@ func TestViewGroup(t *testing.T) {
|
||||
withRoles: false,
|
||||
groupID: wrongID,
|
||||
svcRes: groups.Group{},
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
svcErr: svcerr.ErrNotFound,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
|
||||
},
|
||||
{
|
||||
desc: "view group with service response that cannot be unmarshalled",
|
||||
@@ -634,7 +634,7 @@ func TestViewGroup(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
{
|
||||
desc: "view group with empty id",
|
||||
@@ -822,7 +822,7 @@ func TestUpdateGroup(t *testing.T) {
|
||||
svcRes: groups.Group{},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update group with service response that cannot be unmarshalled",
|
||||
@@ -849,7 +849,7 @@ func TestUpdateGroup(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -984,7 +984,7 @@ func TestUpdateGroupTags(t *testing.T) {
|
||||
svcRes: groups.Group{},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update group tags with a response that can't be unmarshalled",
|
||||
@@ -1001,7 +1001,7 @@ func TestUpdateGroupTags(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1096,7 +1096,7 @@ func TestEnableGroup(t *testing.T) {
|
||||
svcRes: groups.Group{},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "enable group with service response that cannot be unmarshalled",
|
||||
@@ -1112,7 +1112,7 @@ func TestEnableGroup(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1207,7 +1207,7 @@ func TestDisableGroup(t *testing.T) {
|
||||
svcRes: groups.Group{},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "disable group with service response that cannot be unmarshalled",
|
||||
@@ -1223,7 +1223,7 @@ func TestDisableGroup(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1389,7 +1389,7 @@ func TestSetGroupParent(t *testing.T) {
|
||||
groupID: "",
|
||||
parentID: parentID,
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "set group parent with empty parent id",
|
||||
@@ -1398,7 +1398,7 @@ func TestSetGroupParent(t *testing.T) {
|
||||
groupID: groupID,
|
||||
parentID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1485,7 +1485,7 @@ func TestRemoveGroupParent(t *testing.T) {
|
||||
groupID: "",
|
||||
parentID: parentID,
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1572,7 +1572,7 @@ func TestAddChildrenGroups(t *testing.T) {
|
||||
groupID: "",
|
||||
childrenIDs: []string{childID},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "add children group with empty children ids",
|
||||
@@ -1581,7 +1581,7 @@ func TestAddChildrenGroups(t *testing.T) {
|
||||
groupID: groupID,
|
||||
childrenIDs: []string{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingChildrenGroupIDs), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingChildrenGroupIDs, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1668,7 +1668,7 @@ func TestRemoveChildrenGroups(t *testing.T) {
|
||||
groupID: "",
|
||||
childrenIDs: []string{childID},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "remove children group with empty children ids",
|
||||
@@ -1677,7 +1677,7 @@ func TestRemoveChildrenGroups(t *testing.T) {
|
||||
groupID: groupID,
|
||||
childrenIDs: []string{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingChildrenGroupIDs), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingChildrenGroupIDs, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1757,7 +1757,7 @@ func TestRemoveAllChildrenGroups(t *testing.T) {
|
||||
token: validToken,
|
||||
groupID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1925,7 +1925,7 @@ func TestListChildrenGroups(t *testing.T) {
|
||||
svcRes: groups.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.GroupsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list children groups with given metadata",
|
||||
@@ -1978,7 +1978,7 @@ func TestListChildrenGroups(t *testing.T) {
|
||||
svcRes: groups.Page{},
|
||||
svcErr: nil,
|
||||
response: sdk.GroupsPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "list children groups with service response that cannot be unmarshalled",
|
||||
@@ -2010,7 +2010,7 @@ func TestListChildrenGroups(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.GroupsPage{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2186,7 +2186,7 @@ func TestHierarchy(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.GroupsHierarchyPage{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2307,7 +2307,7 @@ func TestCreateGroupRole(t *testing.T) {
|
||||
svcRes: roles.RoleProvision{},
|
||||
svcErr: nil,
|
||||
response: sdk.Role{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "create group role with empty role name",
|
||||
@@ -2322,7 +2322,7 @@ func TestCreateGroupRole(t *testing.T) {
|
||||
svcRes: roles.RoleProvision{},
|
||||
svcErr: nil,
|
||||
response: sdk.Role{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleName), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleName, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2454,7 +2454,7 @@ func TestListGroupRoles(t *testing.T) {
|
||||
svcRes: roles.RolePage{},
|
||||
svcErr: nil,
|
||||
response: sdk.RolesPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2560,7 +2560,7 @@ func TestViewGroupRole(t *testing.T) {
|
||||
svcRes: roles.Role{},
|
||||
svcErr: nil,
|
||||
response: sdk.Role{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "view group role with invalid role id",
|
||||
@@ -2690,7 +2690,7 @@ func TestUpdateGroupRole(t *testing.T) {
|
||||
svcRes: roles.Role{},
|
||||
svcErr: nil,
|
||||
response: sdk.Role{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2778,7 +2778,7 @@ func TestDeleteGroupRole(t *testing.T) {
|
||||
domainID: domainID,
|
||||
groupID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "delete group role with invalid role id",
|
||||
@@ -2889,7 +2889,7 @@ func TestAddGroupRoleActions(t *testing.T) {
|
||||
roleID: roleID,
|
||||
actions: actions,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "add group role actions with invalid role id",
|
||||
@@ -2911,7 +2911,7 @@ func TestAddGroupRoleActions(t *testing.T) {
|
||||
actions: []string{},
|
||||
svcErr: nil,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3004,7 +3004,7 @@ func TestListGroupRoleActions(t *testing.T) {
|
||||
domainID: domainID,
|
||||
groupID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list group role actions with invalid role id",
|
||||
@@ -3022,7 +3022,7 @@ func TestListGroupRoleActions(t *testing.T) {
|
||||
groupID: groupID,
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3117,7 +3117,7 @@ func TestRemoveGroupRoleActions(t *testing.T) {
|
||||
groupID: "",
|
||||
roleID: roleID,
|
||||
actions: actions,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "remove group role actions with invalid role id",
|
||||
@@ -3137,7 +3137,7 @@ func TestRemoveGroupRoleActions(t *testing.T) {
|
||||
roleID: roleID,
|
||||
actions: []string{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3224,7 +3224,7 @@ func TestRemoveAllGroupRoleActions(t *testing.T) {
|
||||
domainID: domainID,
|
||||
groupID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "remove all group role actions with invalid role id",
|
||||
@@ -3242,7 +3242,7 @@ func TestRemoveAllGroupRoleActions(t *testing.T) {
|
||||
groupID: groupID,
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3344,7 +3344,7 @@ func TestAddGroupRoleMembers(t *testing.T) {
|
||||
roleID: roleID,
|
||||
members: members,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "add group role members with invalid role id",
|
||||
@@ -3366,7 +3366,7 @@ func TestAddGroupRoleMembers(t *testing.T) {
|
||||
members: []string{},
|
||||
svcErr: nil,
|
||||
response: []string{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3490,7 +3490,7 @@ func TestListGroupRoleMembers(t *testing.T) {
|
||||
},
|
||||
groupID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list group role members with invalid role id",
|
||||
@@ -3516,7 +3516,7 @@ func TestListGroupRoleMembers(t *testing.T) {
|
||||
},
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3611,7 +3611,7 @@ func TestRemoveGroupRoleMembers(t *testing.T) {
|
||||
groupID: "",
|
||||
roleID: roleID,
|
||||
members: members,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "remove group role members with invalid role id",
|
||||
@@ -3631,7 +3631,7 @@ func TestRemoveGroupRoleMembers(t *testing.T) {
|
||||
roleID: roleID,
|
||||
members: []string{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3718,7 +3718,7 @@ func TestRemoveAllGroupRoleMembers(t *testing.T) {
|
||||
domainID: domainID,
|
||||
groupID: "",
|
||||
roleID: roleID,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "remove all group role members with invalid role id",
|
||||
@@ -3736,7 +3736,7 @@ func TestRemoveAllGroupRoleMembers(t *testing.T) {
|
||||
groupID: groupID,
|
||||
roleID: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestSendInvitation(t *testing.T) {
|
||||
},
|
||||
svcReq: domains.Invitation{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "send invitation with empty role ID",
|
||||
@@ -97,7 +97,7 @@ func TestSendInvitation(t *testing.T) {
|
||||
},
|
||||
svcReq: domains.Invitation{},
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "send inviation with invalid domainID",
|
||||
@@ -218,7 +218,7 @@ func TestListInvitation(t *testing.T) {
|
||||
svcRes: domains.InvitationPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.InvitationPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
|
||||
@@ -225,7 +225,7 @@ func TestRetrieveJournal(t *testing.T) {
|
||||
svcRes: journal.JournalsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.JournalsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidEntityType), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidEntityType, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "retrieve journal with empty entity ID",
|
||||
@@ -273,7 +273,7 @@ func TestRetrieveJournal(t *testing.T) {
|
||||
svcRes: journal.JournalsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.JournalsPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "retrieve journal with invalid page metadata",
|
||||
@@ -292,7 +292,7 @@ func TestRetrieveJournal(t *testing.T) {
|
||||
svcRes: journal.JournalsPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.JournalsPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "retrieve journal with response that cannot be unmarshalled",
|
||||
@@ -325,7 +325,7 @@ func TestRetrieveJournal(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.JournalsPage{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
|
||||
@@ -153,7 +153,7 @@ func TestSendMessage(t *testing.T) {
|
||||
authRes: &grpcClientsV1.AuthnRes{Authenticated: true, Id: ""},
|
||||
authErr: nil,
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrEmptyMessage), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrEmptyMessage, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "publish message with channel subtopic",
|
||||
|
||||
@@ -85,7 +85,7 @@ func TestIssueToken(t *testing.T) {
|
||||
svcRes: &grpcTokenV1.Token{},
|
||||
svcErr: nil,
|
||||
response: sdk.Token{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingUsernameEmail), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingUsernameEmail, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "issue token with empty secret",
|
||||
@@ -96,7 +96,7 @@ func TestIssueToken(t *testing.T) {
|
||||
svcRes: &grpcTokenV1.Token{},
|
||||
svcErr: nil,
|
||||
response: sdk.Token{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
|
||||
+45
-45
@@ -136,7 +136,7 @@ func TestCreateUser(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingUsername), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingUsername, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "register user with first name too long",
|
||||
@@ -151,7 +151,7 @@ func TestCreateUser(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "register user with empty userName",
|
||||
@@ -171,7 +171,7 @@ func TestCreateUser(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingUsername), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingUsername, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "register user with empty secret",
|
||||
@@ -191,7 +191,7 @@ func TestCreateUser(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "register user with secret that is too short",
|
||||
@@ -211,7 +211,7 @@ func TestCreateUser(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrPasswordFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrPasswordFormat, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "register a user with request that can't be marshalled",
|
||||
@@ -232,7 +232,7 @@ func TestCreateUser(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "register a user with response that can't be unmarshalled",
|
||||
@@ -254,7 +254,7 @@ func TestCreateUser(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -412,7 +412,7 @@ func TestListUsers(t *testing.T) {
|
||||
svcRes: users.UsersPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.UsersPage{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "list users with given metadata",
|
||||
@@ -523,7 +523,7 @@ func TestListUsers(t *testing.T) {
|
||||
svcRes: users.UsersPage{},
|
||||
svcErr: nil,
|
||||
response: sdk.UsersPage{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "list users with response that can't be unmarshalled",
|
||||
@@ -553,7 +553,7 @@ func TestListUsers(t *testing.T) {
|
||||
},
|
||||
},
|
||||
response: sdk.UsersPage{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -671,7 +671,7 @@ func TestSearchUsers(t *testing.T) {
|
||||
Limit: limit,
|
||||
FirstName: "",
|
||||
},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrEmptySearchQuery), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrEmptySearchQuery, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "search for users with invalid length of query",
|
||||
@@ -681,7 +681,7 @@ func TestSearchUsers(t *testing.T) {
|
||||
Limit: limit,
|
||||
Username: "a",
|
||||
},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrLenSearchQuery, apiutil.ErrValidation), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrValidation, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "search for users with invalid limit",
|
||||
@@ -691,7 +691,7 @@ func TestSearchUsers(t *testing.T) {
|
||||
Limit: 0,
|
||||
Username: "user_10",
|
||||
},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -761,9 +761,9 @@ func TestViewUser(t *testing.T) {
|
||||
token: validToken,
|
||||
userID: wrongID,
|
||||
svcRes: users.User{},
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
svcErr: svcerr.ErrNotFound,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
|
||||
},
|
||||
{
|
||||
desc: "view user with empty id",
|
||||
@@ -788,7 +788,7 @@ func TestViewUser(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -867,7 +867,7 @@ func TestUserProfile(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1009,7 +1009,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update user with response that can't be unmarshalled",
|
||||
@@ -1031,7 +1031,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1155,7 +1155,7 @@ func TestUpdateUserTags(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update user tags with request that can't be marshalled",
|
||||
@@ -1170,7 +1170,7 @@ func TestUpdateUserTags(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update user tags with response that can't be unmarshalled",
|
||||
@@ -1192,7 +1192,7 @@ func TestUpdateUserTags(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1318,7 +1318,7 @@ func TestUpdateUserEmail(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update email with response that can't be unmarshalled",
|
||||
@@ -1340,7 +1340,7 @@ func TestUpdateUserEmail(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1398,9 +1398,9 @@ func TestResetPasswordRequest(t *testing.T) {
|
||||
desc: "reset password request with invalid email",
|
||||
email: "invalidemail",
|
||||
svcRes: users.User{},
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
svcErr: svcerr.ErrNotFound,
|
||||
issueRes: &grpcTokenV1.Token{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
|
||||
},
|
||||
{
|
||||
desc: "reset password request with empty email",
|
||||
@@ -1408,7 +1408,7 @@ func TestResetPasswordRequest(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
issueRes: &grpcTokenV1.Token{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingEmail), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingEmail, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1478,7 +1478,7 @@ func TestResetPassword(t *testing.T) {
|
||||
newPassword: "",
|
||||
confPassword: newPassword,
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "reset password with empty confirm password",
|
||||
@@ -1487,7 +1487,7 @@ func TestResetPassword(t *testing.T) {
|
||||
newPassword: newPassword,
|
||||
confPassword: "",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingConfPass), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingConfPass, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "reset password with new password not matching confirm password",
|
||||
@@ -1496,7 +1496,7 @@ func TestResetPassword(t *testing.T) {
|
||||
newPassword: newPassword,
|
||||
confPassword: "wrongPassword",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidResetPass), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidResetPass, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "reset password with weak password",
|
||||
@@ -1505,7 +1505,7 @@ func TestResetPassword(t *testing.T) {
|
||||
newPassword: "weak",
|
||||
confPassword: "weak",
|
||||
svcErr: nil,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrPasswordFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrPasswordFormat, http.StatusBadRequest),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1590,7 +1590,7 @@ func TestUpdatePassword(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update password with empty new password",
|
||||
@@ -1600,7 +1600,7 @@ func TestUpdatePassword(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update password with invalid new password",
|
||||
@@ -1610,7 +1610,7 @@ func TestUpdatePassword(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrPasswordFormat), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrPasswordFormat, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update password with invalid old password",
|
||||
@@ -1636,7 +1636,7 @@ func TestUpdatePassword(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1757,7 +1757,7 @@ func TestUpdateUserRole(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update user role with request that can't be marshalled",
|
||||
@@ -1772,7 +1772,7 @@ func TestUpdateUserRole(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update user role with response that can't be unmarshalled",
|
||||
@@ -1794,7 +1794,7 @@ func TestUpdateUserRole(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -1930,7 +1930,7 @@ func TestUpdateUsername(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update username with response that can't be unmarshalled",
|
||||
@@ -1958,7 +1958,7 @@ func TestUpdateUsername(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -2085,7 +2085,7 @@ func TestUpdateProfilePicture(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "update profile picture with request that can't be marshalled",
|
||||
@@ -2100,7 +2100,7 @@ func TestUpdateProfilePicture(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
|
||||
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
|
||||
},
|
||||
{
|
||||
desc: "update profile picture with response that can't be unmarshalled",
|
||||
@@ -2121,7 +2121,7 @@ func TestUpdateProfilePicture(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -2284,7 +2284,7 @@ func TestDisableUser(t *testing.T) {
|
||||
svcRes: users.User{},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
|
||||
},
|
||||
{
|
||||
desc: "disable user with response that can't be unmarshalled",
|
||||
@@ -2299,7 +2299,7 @@ func TestDisableUser(t *testing.T) {
|
||||
},
|
||||
svcErr: nil,
|
||||
response: sdk.User{},
|
||||
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
|
||||
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
|
||||
+139
-175
@@ -135,30 +135,8 @@ func TestRegister(t *testing.T) {
|
||||
user: user,
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusConflict,
|
||||
err: svcerr.ErrConflict,
|
||||
},
|
||||
{
|
||||
desc: "register a new user with an empty token",
|
||||
user: user,
|
||||
token: "",
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "register a user with an invalid ID",
|
||||
user: users.User{
|
||||
ID: inValid,
|
||||
Email: "user@example.com",
|
||||
Credentials: users.Credentials{
|
||||
Secret: "12345678",
|
||||
},
|
||||
},
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: svcerr.ErrConflict,
|
||||
},
|
||||
{
|
||||
desc: "register a user that can't be marshalled",
|
||||
@@ -174,7 +152,7 @@ func TestRegister(t *testing.T) {
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "register user with invalid status",
|
||||
@@ -206,7 +184,7 @@ func TestRegister(t *testing.T) {
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrNameSize,
|
||||
},
|
||||
{
|
||||
desc: "register user with invalid content type",
|
||||
@@ -214,7 +192,7 @@ func TestRegister(t *testing.T) {
|
||||
token: validToken,
|
||||
contentType: "application/xml",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "register user with empty request body",
|
||||
@@ -222,7 +200,7 @@ func TestRegister(t *testing.T) {
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMissingFirstName,
|
||||
},
|
||||
{
|
||||
desc: "register user with invalid username",
|
||||
@@ -239,7 +217,7 @@ func TestRegister(t *testing.T) {
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidUsername,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -323,7 +301,7 @@ func TestView(t *testing.T) {
|
||||
desc: "view user with invalid ID",
|
||||
token: validToken,
|
||||
id: inValid,
|
||||
status: http.StatusBadRequest,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
authnRes: verifiedSession,
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
@@ -401,7 +379,7 @@ func TestViewProfile(t *testing.T) {
|
||||
desc: "view profile with service error",
|
||||
token: validToken,
|
||||
id: user.ID,
|
||||
status: http.StatusBadRequest,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
authnRes: verifiedSession,
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
@@ -499,7 +477,7 @@ func TestListUsers(t *testing.T) {
|
||||
query: "offset=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
authnRes: verifiedSession,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list users with limit",
|
||||
@@ -522,7 +500,7 @@ func TestListUsers(t *testing.T) {
|
||||
query: "limit=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
authnRes: verifiedSession,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list users with limit greater than max",
|
||||
@@ -530,7 +508,7 @@ func TestListUsers(t *testing.T) {
|
||||
query: fmt.Sprintf("limit=%d", api.MaxLimitSize+1),
|
||||
status: http.StatusBadRequest,
|
||||
authnRes: verifiedSession,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrLimitSize,
|
||||
},
|
||||
{
|
||||
desc: "list users with name",
|
||||
@@ -574,7 +552,7 @@ func TestListUsers(t *testing.T) {
|
||||
query: "status=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
authnRes: verifiedSession,
|
||||
err: apiutil.ErrValidation,
|
||||
err: svcerr.ErrInvalidStatus,
|
||||
},
|
||||
{
|
||||
desc: "list users with duplicate status",
|
||||
@@ -626,7 +604,7 @@ func TestListUsers(t *testing.T) {
|
||||
query: "metadata=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
authnRes: verifiedSession,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list users with duplicate metadata",
|
||||
@@ -658,20 +636,6 @@ func TestListUsers(t *testing.T) {
|
||||
authnRes: verifiedSession,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list users with list perms",
|
||||
token: validToken,
|
||||
listUsersResponse: users.UsersPage{
|
||||
Page: users.Page{
|
||||
Total: 1,
|
||||
},
|
||||
Users: []users.User{user},
|
||||
},
|
||||
query: "list_perms=true",
|
||||
status: http.StatusOK,
|
||||
authnRes: verifiedSession,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list users with duplicate list perms",
|
||||
token: validToken,
|
||||
@@ -702,14 +666,6 @@ func TestListUsers(t *testing.T) {
|
||||
authnRes: verifiedSession,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list users with duplicate list perms",
|
||||
token: validToken,
|
||||
query: "list_perms=true&list_perms=true",
|
||||
status: http.StatusBadRequest,
|
||||
authnRes: verifiedSession,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list users with email",
|
||||
token: validToken,
|
||||
@@ -764,7 +720,7 @@ func TestListUsers(t *testing.T) {
|
||||
query: "dir=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
authnRes: verifiedSession,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidDirection,
|
||||
},
|
||||
{
|
||||
desc: "list users with duplicate order direction",
|
||||
@@ -908,7 +864,7 @@ func TestSearchUsers(t *testing.T) {
|
||||
desc: "serach users with service error",
|
||||
token: validToken,
|
||||
query: "username=username",
|
||||
status: http.StatusBadRequest,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
@@ -1028,7 +984,7 @@ func TestUpdate(t *testing.T) {
|
||||
authnRes: verifiedSession,
|
||||
contentType: "application/xml",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update user with malformed data",
|
||||
@@ -1038,7 +994,7 @@ func TestUpdate(t *testing.T) {
|
||||
authnRes: verifiedSession,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "update user with empty id",
|
||||
@@ -1047,8 +1003,8 @@ func TestUpdate(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1167,7 +1123,7 @@ func TestUpdateTags(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update user tags with empty id",
|
||||
@@ -1177,7 +1133,7 @@ func TestUpdateTags(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "update user with malfomed data",
|
||||
@@ -1187,7 +1143,7 @@ func TestUpdateTags(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1334,7 +1290,7 @@ func TestUpdateEmail(t *testing.T) {
|
||||
contentType: "application/xml",
|
||||
token: validToken,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update user email with malformed data",
|
||||
@@ -1349,7 +1305,7 @@ func TestUpdateEmail(t *testing.T) {
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "update user email with service error",
|
||||
@@ -1368,29 +1324,31 @@ func TestUpdateEmail(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPatch,
|
||||
url: fmt.Sprintf("%s/users/%s/email", us.URL, tc.user.ID),
|
||||
contentType: tc.contentType,
|
||||
token: tc.token,
|
||||
body: strings.NewReader(tc.data),
|
||||
}
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPatch,
|
||||
url: fmt.Sprintf("%s/users/%s/email", us.URL, tc.user.ID),
|
||||
contentType: tc.contentType,
|
||||
token: tc.token,
|
||||
body: strings.NewReader(tc.data),
|
||||
}
|
||||
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("UpdateEmail", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.user, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("UpdateEmail", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.user, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1486,7 +1444,7 @@ func TestUpdateUsername(t *testing.T) {
|
||||
contentType: "application/xml",
|
||||
token: validToken,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update user email with malformed data",
|
||||
@@ -1501,7 +1459,7 @@ func TestUpdateUsername(t *testing.T) {
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "update username with invalid username",
|
||||
@@ -1521,29 +1479,31 @@ func TestUpdateUsername(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPatch,
|
||||
url: fmt.Sprintf("%s/users/%s/username", us.URL, tc.user.ID),
|
||||
contentType: tc.contentType,
|
||||
token: tc.token,
|
||||
body: strings.NewReader(tc.data),
|
||||
}
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPatch,
|
||||
url: fmt.Sprintf("%s/users/%s/username", us.URL, tc.user.ID),
|
||||
contentType: tc.contentType,
|
||||
token: tc.token,
|
||||
body: strings.NewReader(tc.data),
|
||||
}
|
||||
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("UpdateUsername", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.user, tc.err)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("UpdateUsername", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.user, tc.err)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1624,7 +1584,7 @@ func TestUpdateProfilePicture(t *testing.T) {
|
||||
contentType: "application/xml",
|
||||
token: validToken,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update profile picture with malformed data",
|
||||
@@ -1634,7 +1594,7 @@ func TestUpdateProfilePicture(t *testing.T) {
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "update profile picture with failed to update",
|
||||
@@ -1652,29 +1612,31 @@ func TestUpdateProfilePicture(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPatch,
|
||||
url: fmt.Sprintf("%s/users/%s/picture", us.URL, tc.user.ID),
|
||||
contentType: tc.contentType,
|
||||
token: tc.token,
|
||||
body: strings.NewReader(tc.data),
|
||||
}
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPatch,
|
||||
url: fmt.Sprintf("%s/users/%s/picture", us.URL, tc.user.ID),
|
||||
contentType: tc.contentType,
|
||||
token: tc.token,
|
||||
body: strings.NewReader(tc.data),
|
||||
}
|
||||
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("UpdateProfilePicture", mock.Anything, tc.authnRes, tc.user.ID, mock.Anything).Return(tc.user, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("UpdateProfilePicture", mock.Anything, tc.authnRes, tc.user.ID, mock.Anything).Return(tc.user, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1868,14 +1830,14 @@ func TestVerifyEmail(t *testing.T) {
|
||||
desc: "verify email with empty token",
|
||||
token: "",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrInvalidVerification,
|
||||
},
|
||||
{
|
||||
desc: "verify email with service error",
|
||||
token: validToken,
|
||||
status: http.StatusBadRequest,
|
||||
svcErr: svcerr.ErrMalformedEntity,
|
||||
err: svcerr.ErrMalformedEntity,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
svcErr: svcerr.ErrUpdateEntity,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2101,7 +2063,7 @@ func TestUpdateRole(t *testing.T) {
|
||||
authnRes: verifiedSession,
|
||||
contentType: "application/xml",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update user with malformed data",
|
||||
@@ -2111,7 +2073,7 @@ func TestUpdateRole(t *testing.T) {
|
||||
authnRes: verifiedSession,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "update user with service error",
|
||||
@@ -2251,7 +2213,7 @@ func TestUpdateSecret(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update user secret with malformed data",
|
||||
@@ -2267,7 +2229,7 @@ func TestUpdateSecret(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2326,14 +2288,14 @@ func TestIssueToken(t *testing.T) {
|
||||
data: fmt.Sprintf(dataFormat, "", secret),
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMissingUsernameEmail,
|
||||
},
|
||||
{
|
||||
desc: "issue token with empty secret",
|
||||
data: fmt.Sprintf(dataFormat, validUsername, ""),
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMissingPass,
|
||||
},
|
||||
{
|
||||
desc: "issue token with invalid email",
|
||||
@@ -2347,14 +2309,14 @@ func TestIssueToken(t *testing.T) {
|
||||
data: fmt.Sprintf(dataFormat, validUsername, secret),
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "issue token with invalid contentype",
|
||||
data: fmt.Sprintf(dataFormat, "invalid", secret),
|
||||
contentType: "application/xml",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2443,7 +2405,7 @@ func TestRefreshToken(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "refresh token with invalid contentype",
|
||||
@@ -2452,35 +2414,37 @@ func TestRefreshToken(t *testing.T) {
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrValidation,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPost,
|
||||
url: fmt.Sprintf("%s/users/tokens/refresh", us.URL),
|
||||
contentType: tc.contentType,
|
||||
body: strings.NewReader(tc.data),
|
||||
token: tc.token,
|
||||
}
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("RefreshToken", mock.Anything, tc.authnRes, tc.token, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
if tc.err != nil {
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPost,
|
||||
url: fmt.Sprintf("%s/users/tokens/refresh", us.URL),
|
||||
contentType: tc.contentType,
|
||||
body: strings.NewReader(tc.data),
|
||||
token: tc.token,
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("RefreshToken", mock.Anything, tc.authnRes, tc.token, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
if tc.err != nil {
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
|
||||
const (
|
||||
valid = "valid"
|
||||
invalid = "invalid"
|
||||
secret = "QJg58*aMan7j"
|
||||
name = "user"
|
||||
validEmail = "example@domain.com"
|
||||
|
||||
+11
-11
@@ -384,7 +384,7 @@ func decodeUpdateUser(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "id"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -399,7 +399,7 @@ func decodeUpdateUserTags(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "id"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -414,7 +414,7 @@ func decodeUpdateUserEmail(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "id"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -427,7 +427,7 @@ func decodeUpdateUserSecret(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
req := updateUserSecretReq{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -442,7 +442,7 @@ func decodeUpdateUsername(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "id"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -458,7 +458,7 @@ func decodeUpdateUserProfilePicture(_ context.Context, r *http.Request) (any, er
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -471,7 +471,7 @@ func decodePasswordResetRequest(_ context.Context, r *http.Request) (any, error)
|
||||
|
||||
var req passResetReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -484,7 +484,7 @@ func decodePasswordReset(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
var req resetTokenReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -499,7 +499,7 @@ func decodeUpdateUserRole(_ context.Context, r *http.Request) (any, error) {
|
||||
id: chi.URLParam(r, "id"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
var err error
|
||||
req.role, err = users.ToRole(req.Role)
|
||||
@@ -513,7 +513,7 @@ func decodeCredentials(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
req := loginUserReq{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
@@ -535,7 +535,7 @@ func decodeCreateUserReq(_ context.Context, r *http.Request) (any, error) {
|
||||
|
||||
var req createUserReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
const cost int = 10
|
||||
|
||||
var (
|
||||
errHashPassword = errors.New("generate hash from password failed")
|
||||
errComparePassword = errors.New("compare hash and password failed")
|
||||
errHashPassword = errors.NewServiceError("generate hash from password failed")
|
||||
errComparePassword = errors.NewServiceError("compare hash and password failed")
|
||||
)
|
||||
|
||||
var _ users.Hasher = (*bcryptHasher)(nil)
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import "github.com/absmach/supermq/pkg/errors"
|
||||
|
||||
var _ errors.Mapper = (*duplicateErrors)(nil)
|
||||
|
||||
type duplicateErrors struct{}
|
||||
|
||||
// GetError maps constraint names to known errors.
|
||||
func (d duplicateErrors) GetError(constraint string) (error, bool) {
|
||||
switch constraint {
|
||||
case "clients_email_key":
|
||||
return errors.NewRequestError("email id already registered"), true
|
||||
case "clients_username_key":
|
||||
return errors.NewRequestError("username not available"), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func NewDuplicateErrors() errors.Mapper {
|
||||
return duplicateErrors{}
|
||||
}
|
||||
+39
-49
@@ -18,18 +18,20 @@ import (
|
||||
"github.com/absmach/supermq/pkg/postgres"
|
||||
"github.com/absmach/supermq/users"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
var pgDuplicateErrCode = "23505"
|
||||
|
||||
type userRepo struct {
|
||||
Repository users.UserRepository
|
||||
eh errors.Handler
|
||||
}
|
||||
|
||||
func NewRepository(db postgres.Database) users.Repository {
|
||||
errHandlerOptions := []errors.HandlerOption{
|
||||
postgres.WithDuplicateErrors(NewDuplicateErrors()),
|
||||
}
|
||||
return &userRepo{
|
||||
Repository: users.UserRepository{DB: db},
|
||||
eh: postgres.NewErrorHandler(errHandlerOptions...),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,12 +42,12 @@ func (repo *userRepo) Save(ctx context.Context, c users.User) (users.User, error
|
||||
|
||||
dbu, err := toDBUser(c)
|
||||
if err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrCreateEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
|
||||
}
|
||||
|
||||
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
|
||||
if err != nil {
|
||||
return users.User{}, handleSaveError(repoerr.ErrCreateEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
defer row.Close()
|
||||
@@ -54,40 +56,28 @@ func (repo *userRepo) Save(ctx context.Context, c users.User) (users.User, error
|
||||
|
||||
dbu = DBUser{}
|
||||
if err := row.StructScan(&dbu); err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrFailedOpDB, err)
|
||||
}
|
||||
|
||||
user, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func handleSaveError(wrapper, err error) error {
|
||||
if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgDuplicateErrCode {
|
||||
switch pqErr.ConstraintName {
|
||||
case "clients_email_key":
|
||||
return errors.ErrEmailAlreadyExists
|
||||
case "clients_username_key":
|
||||
return errors.ErrUsernameNotAvailable
|
||||
}
|
||||
}
|
||||
return postgres.HandleError(wrapper, err)
|
||||
}
|
||||
|
||||
func (repo *userRepo) CheckSuperAdmin(ctx context.Context, adminID string) error {
|
||||
q := "SELECT 1 FROM users WHERE id = $1 AND role = $2"
|
||||
rows, err := repo.Repository.DB.QueryContext(ctx, q, adminID, users.AdminRole)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -105,7 +95,7 @@ func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User,
|
||||
|
||||
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
|
||||
if err != nil {
|
||||
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -115,12 +105,12 @@ func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User,
|
||||
}
|
||||
|
||||
if err = rows.StructScan(&dbu); err != nil {
|
||||
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
user, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
@@ -129,7 +119,7 @@ func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User,
|
||||
func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.UsersPage, error) {
|
||||
query, err := PageQuery(pm)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
|
||||
}
|
||||
|
||||
squery := applyOrdering(query, pm)
|
||||
@@ -140,26 +130,26 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
|
||||
|
||||
dbPage, err := ToDBUsersPage(pm)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
|
||||
}
|
||||
|
||||
var items []users.User
|
||||
if !pm.OnlyTotal {
|
||||
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
dbu := DBUser{}
|
||||
if err := rows.StructScan(&dbu); err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, err
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
@@ -170,7 +160,7 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
|
||||
|
||||
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := users.UsersPage{
|
||||
@@ -242,12 +232,12 @@ func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) (
|
||||
func (repo *userRepo) update(ctx context.Context, user users.User, query string) (users.User, error) {
|
||||
dbu, err := toDBUser(user)
|
||||
if err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
|
||||
}
|
||||
|
||||
row, err := repo.Repository.DB.NamedQueryContext(ctx, query, dbu)
|
||||
if err != nil {
|
||||
return users.User{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
@@ -257,7 +247,7 @@ func (repo *userRepo) update(ctx context.Context, user users.User, query string)
|
||||
}
|
||||
|
||||
if err := row.StructScan(&dbu); err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
|
||||
}
|
||||
|
||||
return ToUser(dbu)
|
||||
@@ -308,7 +298,7 @@ func (repo *userRepo) Delete(ctx context.Context, id string) error {
|
||||
|
||||
result, err := repo.Repository.DB.ExecContext(ctx, q, id)
|
||||
if err != nil {
|
||||
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
return repo.eh.HandleError(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
if rows, _ := result.RowsAffected(); rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
@@ -320,7 +310,7 @@ func (repo *userRepo) Delete(ctx context.Context, id string) error {
|
||||
func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.UsersPage, error) {
|
||||
query, err := PageQuery(pm)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
|
||||
}
|
||||
|
||||
tq := query
|
||||
@@ -330,12 +320,12 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
|
||||
|
||||
dbPage, err := ToDBUsersPage(pm)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
|
||||
}
|
||||
|
||||
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -343,7 +333,7 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
|
||||
for rows.Next() {
|
||||
dbu := DBUser{}
|
||||
if err := rows.StructScan(&dbu); err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToUser(dbu)
|
||||
@@ -358,7 +348,7 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
|
||||
|
||||
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := users.UsersPage{
|
||||
@@ -381,7 +371,7 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
|
||||
}
|
||||
query, err := PageQuery(pm)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
|
||||
}
|
||||
squery := applyOrdering(query, pm)
|
||||
|
||||
@@ -389,11 +379,11 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
|
||||
u.created_at, u.updated_at, COALESCE(u.updated_by, '') AS updated_by FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
|
||||
dbPage, err := ToDBUsersPage(pm)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
|
||||
}
|
||||
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@@ -401,12 +391,12 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
|
||||
for rows.Next() {
|
||||
dbu := DBUser{}
|
||||
if err := rows.StructScan(&dbu); err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
c, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, err
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
|
||||
}
|
||||
|
||||
items = append(items, c)
|
||||
@@ -415,7 +405,7 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
|
||||
|
||||
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
|
||||
if err != nil {
|
||||
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
page := users.UsersPage{
|
||||
@@ -441,14 +431,14 @@ func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users.
|
||||
|
||||
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
|
||||
if err != nil {
|
||||
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbu = DBUser{}
|
||||
if row.Next() {
|
||||
if err := row.StructScan(&dbu); err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return ToUser(dbu)
|
||||
@@ -468,14 +458,14 @@ func (repo *userRepo) RetrieveByUsername(ctx context.Context, username string) (
|
||||
|
||||
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
|
||||
if err != nil {
|
||||
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
dbu = DBUser{}
|
||||
if row.Next() {
|
||||
if err := row.StructScan(&dbu); err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return ToUser(dbu)
|
||||
|
||||
@@ -82,7 +82,6 @@ func TestUsersSave(t *testing.T) {
|
||||
user: externalUser,
|
||||
err: nil,
|
||||
},
|
||||
|
||||
{
|
||||
desc: "add user with duplicate user email",
|
||||
user: users.User{
|
||||
@@ -129,7 +128,7 @@ func TestUsersSave(t *testing.T) {
|
||||
Metadata: users.Metadata{},
|
||||
Status: users.EnabledStatus,
|
||||
},
|
||||
err: errors.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add user with invalid user name",
|
||||
@@ -145,7 +144,7 @@ func TestUsersSave(t *testing.T) {
|
||||
Metadata: users.Metadata{},
|
||||
Status: users.EnabledStatus,
|
||||
},
|
||||
err: errors.ErrMalformedEntity,
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "add user with a missing username",
|
||||
@@ -194,12 +193,14 @@ func TestUsersSave(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
rUser, err := repo.Save(context.Background(), tc.user)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
rUser.Credentials.Secret = tc.user.Credentials.Secret
|
||||
assert.Equal(t, tc.user, rUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, rUser))
|
||||
}
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
rUser, err := repo.Save(context.Background(), tc.user)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
rUser.Credentials.Secret = tc.user.Credentials.Secret
|
||||
assert.Equal(t, tc.user, rUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, rUser))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1307,7 +1308,7 @@ func TestUpdate(t *testing.T) {
|
||||
userReq: users.UserReq{
|
||||
Metadata: &malformedMetadata,
|
||||
},
|
||||
err: repoerr.ErrUpdateEntity,
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
},
|
||||
{
|
||||
desc: "update empty metadata for enabled user",
|
||||
@@ -1564,7 +1565,7 @@ func TestUpdateUsername(t *testing.T) {
|
||||
Username: user2.Credentials.Username,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrConflict,
|
||||
err: errors.ErrUsernameNotAvailable,
|
||||
},
|
||||
{
|
||||
desc: "for disabled user",
|
||||
@@ -1984,16 +1985,18 @@ func TestRetrieveByIDs(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
switch response, err := repo.RetrieveAllByIDs(context.Background(), c.page); {
|
||||
case err == nil:
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", c.desc, c.err, err))
|
||||
assert.Equal(t, c.response.Total, response.Total)
|
||||
assert.Equal(t, c.response.Limit, response.Limit)
|
||||
assert.Equal(t, c.response.Offset, response.Offset)
|
||||
assert.ElementsMatch(t, response.Users, c.response.Users)
|
||||
default:
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
|
||||
}
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
switch response, err := repo.RetrieveAllByIDs(context.Background(), c.page); {
|
||||
case err == nil:
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", c.desc, c.err, err))
|
||||
assert.Equal(t, c.response.Total, response.Total)
|
||||
assert.Equal(t, c.response.Limit, response.Limit)
|
||||
assert.Equal(t, c.response.Offset, response.Offset)
|
||||
assert.ElementsMatch(t, response.Users, c.response.Users)
|
||||
default:
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+25
-15
@@ -28,9 +28,11 @@ import (
|
||||
const defaultUsernamePrefix = "user"
|
||||
|
||||
var (
|
||||
errIssueToken = errors.New("failed to issue token")
|
||||
errRecoveryToken = errors.New("failed to generate password recovery token")
|
||||
errLoginDisableUser = errors.New("failed to login in disabled user")
|
||||
errIssueToken = errors.NewServiceError("failed to issue token")
|
||||
errRecoveryToken = errors.NewServiceError("failed to generate password recovery token")
|
||||
errLoginDisableUser = errors.NewAuthNError("failed to login in disabled user")
|
||||
errMatchUserVerification = errors.NewRequestError("user verification does not match with stored verification")
|
||||
errSimilarUpdateEmail = errors.NewRequestError("new email is similar to the current email")
|
||||
|
||||
usernameRegExp = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{34}[a-z0-9]$`)
|
||||
)
|
||||
@@ -65,28 +67,28 @@ func (svc service) Register(ctx context.Context, session authn.Session, u User,
|
||||
|
||||
userID, err := svc.idProvider.ID()
|
||||
if err != nil {
|
||||
return User{}, err
|
||||
return User{}, errors.Wrap(svcerr.ErrIssueProviderID, err)
|
||||
}
|
||||
|
||||
if u.Credentials.Secret != "" {
|
||||
hash, err := svc.hasher.Hash(u.Credentials.Secret)
|
||||
if err != nil {
|
||||
return User{}, errors.Wrap(svcerr.ErrMalformedEntity, err)
|
||||
return User{}, errors.Wrap(svcerr.ErrHashPassword, err)
|
||||
}
|
||||
u.Credentials.Secret = hash
|
||||
}
|
||||
|
||||
if u.Status != DisabledStatus && u.Status != EnabledStatus {
|
||||
return User{}, errors.Wrap(svcerr.ErrMalformedEntity, svcerr.ErrInvalidStatus)
|
||||
return User{}, svcerr.ErrInvalidStatus
|
||||
}
|
||||
if u.Role != UserRole && u.Role != AdminRole {
|
||||
return User{}, errors.Wrap(svcerr.ErrMalformedEntity, svcerr.ErrInvalidRole)
|
||||
return User{}, svcerr.ErrInvalidRole
|
||||
}
|
||||
u.ID = userID
|
||||
u.CreatedAt = time.Now().UTC()
|
||||
|
||||
if err := svc.addUserPolicy(ctx, u.ID, u.Role); err != nil {
|
||||
return User{}, err
|
||||
return User{}, errors.Wrap(svcerr.ErrAddPolicies, err)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
@@ -105,7 +107,7 @@ func (svc service) Register(ctx context.Context, session authn.Session, u User,
|
||||
func (svc service) SendVerification(ctx context.Context, session authn.Session) error {
|
||||
dbUser, err := svc.users.RetrieveByID(ctx, session.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
if !dbUser.VerifiedAt.IsZero() {
|
||||
@@ -114,7 +116,7 @@ func (svc service) SendVerification(ctx context.Context, session authn.Session)
|
||||
|
||||
uv, err := svc.users.RetrieveUserVerification(ctx, dbUser.ID, dbUser.Email)
|
||||
if err != nil && err != repoerr.ErrNotFound {
|
||||
return err
|
||||
return errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
if err = uv.Valid(); err != nil {
|
||||
@@ -150,7 +152,7 @@ func (svc service) VerifyEmail(ctx context.Context, token string) (User, error)
|
||||
}
|
||||
|
||||
if err := stored.Match(received); err != nil {
|
||||
return User{}, err
|
||||
return User{}, errors.Wrap(errMatchUserVerification, err)
|
||||
}
|
||||
|
||||
if err := stored.Valid(); err != nil {
|
||||
@@ -219,8 +221,12 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr
|
||||
if dbUser.Status == DisabledStatus {
|
||||
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser)
|
||||
}
|
||||
token, err := svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken, Verified: !dbUser.VerifiedAt.IsZero()})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err)
|
||||
}
|
||||
|
||||
return svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken, Verified: !dbUser.VerifiedAt.IsZero()})
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (svc service) View(ctx context.Context, session authn.Session, id string) (User, error) {
|
||||
@@ -375,7 +381,7 @@ func (svc service) UpdateEmail(ctx context.Context, session authn.Session, userI
|
||||
return User{}, svcerr.ErrExternalAuthProviderCouldNotUpdate
|
||||
}
|
||||
if oldUsr.Email == email {
|
||||
return User{}, fmt.Errorf("current email is same as update requested email")
|
||||
return User{}, errSimilarUpdateEmail
|
||||
}
|
||||
|
||||
usr := User{
|
||||
@@ -409,7 +415,11 @@ func (svc service) SendPasswordReset(ctx context.Context, email string) error {
|
||||
return errors.Wrap(errRecoveryToken, err)
|
||||
}
|
||||
|
||||
return svc.email.SendPasswordReset([]string{email}, user.Credentials.Username, token.AccessToken)
|
||||
if err := svc.email.SendPasswordReset([]string{email}, user.Credentials.Username, token.AccessToken); err != nil {
|
||||
return errors.NewInternalErrorWithErr(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) ResetSecret(ctx context.Context, session authn.Session, secret string) error {
|
||||
@@ -548,7 +558,7 @@ func (svc service) changeUserStatus(ctx context.Context, session authn.Session,
|
||||
return User{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
if dbu.Status == user.Status {
|
||||
return User{}, errors.ErrStatusAlreadyAssigned
|
||||
return User{}, svcerr.ErrStatusAlreadyAssigned
|
||||
}
|
||||
user.UpdatedBy = session.UserID
|
||||
|
||||
|
||||
+296
-264
@@ -165,7 +165,7 @@ func TestRegister(t *testing.T) {
|
||||
Secret: strings.Repeat("a", 73),
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: errHashPassword,
|
||||
},
|
||||
{
|
||||
desc: "register a new user with invalid status",
|
||||
@@ -221,74 +221,76 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr)
|
||||
policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr)
|
||||
repoCall := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr)
|
||||
expected, err := svc.Register(context.Background(), authn.Session{}, tc.user, true)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
tc.user.ID = expected.ID
|
||||
tc.user.CreatedAt = expected.CreatedAt
|
||||
tc.user.UpdatedAt = expected.UpdatedAt
|
||||
tc.user.Credentials.Secret = expected.Credentials.Secret
|
||||
tc.user.UpdatedBy = expected.UpdatedBy
|
||||
assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected))
|
||||
ok := repoCall.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
policyCall1.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr)
|
||||
policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr)
|
||||
repoCall := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr)
|
||||
expected, err := svc.Register(context.Background(), authn.Session{}, tc.user, true)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
tc.user.ID = expected.ID
|
||||
tc.user.CreatedAt = expected.CreatedAt
|
||||
tc.user.UpdatedAt = expected.UpdatedAt
|
||||
tc.user.Credentials.Secret = expected.Credentials.Secret
|
||||
tc.user.UpdatedBy = expected.UpdatedBy
|
||||
assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected))
|
||||
ok := repoCall.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
policyCall1.Unset()
|
||||
})
|
||||
}
|
||||
|
||||
svc, _, cRepo, policies, _ = newService()
|
||||
|
||||
cases2 := []struct {
|
||||
desc string
|
||||
user users.User
|
||||
session authn.Session
|
||||
addPoliciesResponseErr error
|
||||
deletePoliciesResponseErr error
|
||||
saveErr error
|
||||
checkSuperAdminErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "register new user successfully as admin",
|
||||
user: user,
|
||||
session: authn.Session{UserID: validID, SuperAdmin: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "register a new user as admin with failed check on super admin",
|
||||
user: user,
|
||||
session: authn.Session{UserID: validID, SuperAdmin: false},
|
||||
checkSuperAdminErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases2 {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr)
|
||||
policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr)
|
||||
repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr)
|
||||
expected, err := svc.Register(context.Background(), authn.Session{UserID: validID}, tc.user, false)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
tc.user.ID = expected.ID
|
||||
tc.user.CreatedAt = expected.CreatedAt
|
||||
tc.user.UpdatedAt = expected.UpdatedAt
|
||||
tc.user.Credentials.Secret = expected.Credentials.Secret
|
||||
tc.user.UpdatedBy = expected.UpdatedBy
|
||||
assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected))
|
||||
ok := repoCall1.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall1.Unset()
|
||||
policyCall.Unset()
|
||||
policyCall1.Unset()
|
||||
repoCall.Unset()
|
||||
}
|
||||
// cases2 := []struct {
|
||||
// desc string
|
||||
// user users.User
|
||||
// session authn.Session
|
||||
// addPoliciesResponseErr error
|
||||
// deletePoliciesResponseErr error
|
||||
// saveErr error
|
||||
// checkSuperAdminErr error
|
||||
// err error
|
||||
// }{
|
||||
// {
|
||||
// desc: "register new user successfully as admin",
|
||||
// user: user,
|
||||
// session: authn.Session{UserID: validID, SuperAdmin: true},
|
||||
// err: nil,
|
||||
// },
|
||||
// {
|
||||
// desc: "register a new user as admin with failed check on super admin",
|
||||
// user: user,
|
||||
// session: authn.Session{UserID: validID, SuperAdmin: false},
|
||||
// checkSuperAdminErr: svcerr.ErrAuthorization,
|
||||
// err: svcerr.ErrAuthorization,
|
||||
// },
|
||||
// }
|
||||
// for _, tc := range cases2 {
|
||||
// repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
// policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr)
|
||||
// policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr)
|
||||
// repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr)
|
||||
// expected, err := svc.Register(context.Background(), authn.Session{UserID: validID}, tc.user, false)
|
||||
// assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
// if err == nil {
|
||||
// tc.user.ID = expected.ID
|
||||
// tc.user.CreatedAt = expected.CreatedAt
|
||||
// tc.user.UpdatedAt = expected.UpdatedAt
|
||||
// tc.user.Credentials.Secret = expected.Credentials.Secret
|
||||
// tc.user.UpdatedBy = expected.UpdatedBy
|
||||
// assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected))
|
||||
// ok := repoCall1.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
|
||||
// assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
|
||||
// }
|
||||
// repoCall1.Unset()
|
||||
// policyCall.Unset()
|
||||
// policyCall1.Unset()
|
||||
// repoCall.Unset()
|
||||
// }
|
||||
}
|
||||
|
||||
func TestViewUser(t *testing.T) {
|
||||
@@ -349,18 +351,20 @@ func TestViewUser(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
rUser, err := svc.View(context.Background(), authn.Session{UserID: tc.reqUserID}, tc.userID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
tc.response.Credentials.Secret = ""
|
||||
assert.Equal(t, tc.response, rUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.userID)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall1.Unset()
|
||||
repoCall.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
rUser, err := svc.View(context.Background(), authn.Session{UserID: tc.reqUserID}, tc.userID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
tc.response.Credentials.Secret = ""
|
||||
assert.Equal(t, tc.response, rUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.userID)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall1.Unset()
|
||||
repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -430,17 +434,19 @@ func TestListUsers(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.superAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
page, err := svc.ListUsers(context.Background(), authn.Session{UserID: user.ID}, tc.page)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.superAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
|
||||
page, err := svc.ListUsers(context.Background(), authn.Session{UserID: user.ID}, tc.page)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -494,11 +500,13 @@ func TestSearchUsers(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("SearchUsers", context.Background(), mock.Anything).Return(tc.response, tc.responseErr)
|
||||
page, err := svc.SearchUsers(context.Background(), tc.page)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
|
||||
repoCall.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("SearchUsers", context.Background(), mock.Anything).Return(tc.response, tc.responseErr)
|
||||
page, err := svc.SearchUsers(context.Background(), tc.page)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
|
||||
repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -673,19 +681,21 @@ func TestUpdateUser(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr)
|
||||
repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateResponse, tc.updateErr)
|
||||
updatedUser, err := svc.Update(context.Background(), tc.session, tc.userID, tc.userReq)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr)
|
||||
repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateResponse, tc.updateErr)
|
||||
updatedUser, err := svc.Update(context.Background(), tc.session, tc.userID, tc.userReq)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -750,18 +760,20 @@ func TestUpdateTags(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateUserTagsResponse, tc.updateUserTagsErr)
|
||||
updatedUser, err := svc.UpdateTags(context.Background(), tc.session, tc.userID, tc.userReq)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateUserTagsResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUserTagsResponse, updatedUser))
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateUserTagsResponse, tc.updateUserTagsErr)
|
||||
updatedUser, err := svc.UpdateTags(context.Background(), tc.session, tc.userID, tc.userReq)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateUserTagsResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUserTagsResponse, updatedUser))
|
||||
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -843,22 +855,24 @@ func TestUpdateRole(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
policyCall := policies.On("AddPolicy", context.Background(), mock.Anything).Return(tc.addPolicyErr)
|
||||
policyCall1 := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePolicyErr)
|
||||
repoCall1 := cRepo.On("UpdateRole", context.Background(), mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr)
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
policyCall := policies.On("AddPolicy", context.Background(), mock.Anything).Return(tc.addPolicyErr)
|
||||
policyCall1 := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePolicyErr)
|
||||
repoCall1 := cRepo.On("UpdateRole", context.Background(), mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr)
|
||||
|
||||
updatedUser, err := svc.UpdateRole(context.Background(), tc.session, tc.user)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateRoleResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateRoleResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "UpdateRole", context.Background(), mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
policyCall1.Unset()
|
||||
repoCall1.Unset()
|
||||
updatedUser, err := svc.UpdateRole(context.Background(), tc.session, tc.user)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateRoleResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateRoleResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "UpdateRole", context.Background(), mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
policyCall1.Unset()
|
||||
repoCall1.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -919,7 +933,7 @@ func TestUpdateSecret(t *testing.T) {
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update user secret with invalod old secret",
|
||||
desc: "update user secret with invalid old secret",
|
||||
oldSecret: "invalid",
|
||||
newSecret: newSecret,
|
||||
session: authn.Session{UserID: user.ID},
|
||||
@@ -934,7 +948,7 @@ func TestUpdateSecret(t *testing.T) {
|
||||
session: authn.Session{UserID: user.ID},
|
||||
retrieveByIDResponse: user,
|
||||
retrieveByEmailResponse: rUser,
|
||||
err: repoerr.ErrMalformedEntity,
|
||||
err: errHashPassword,
|
||||
},
|
||||
{
|
||||
desc: "update user secret with failed to update secret",
|
||||
@@ -950,25 +964,27 @@ func TestUpdateSecret(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("RetrieveByID", context.Background(), user.ID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
repoCall1 := cRepo.On("RetrieveByUsername", context.Background(), user.Credentials.Username).Return(tc.retrieveByEmailResponse, tc.retrieveByEmailErr)
|
||||
repoCall2 := cRepo.On("UpdateSecret", context.Background(), mock.Anything).Return(tc.updateSecretResponse, tc.updateSecretErr)
|
||||
authCall := authUser.On("Issue", context.Background(), mock.Anything).Return(tc.issueResponse, tc.issueErr)
|
||||
updatedUser, err := svc.UpdateSecret(context.Background(), tc.session, tc.oldSecret, tc.newSecret)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.response.ID)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = repoCall1.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.response.Credentials.Username)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "UpdateSecret", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("UpdateSecret was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
authCall.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("RetrieveByID", context.Background(), user.ID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
repoCall1 := cRepo.On("RetrieveByUsername", context.Background(), user.Credentials.Username).Return(tc.retrieveByEmailResponse, tc.retrieveByEmailErr)
|
||||
repoCall2 := cRepo.On("UpdateSecret", context.Background(), mock.Anything).Return(tc.updateSecretResponse, tc.updateSecretErr)
|
||||
authCall := authUser.On("Issue", context.Background(), mock.Anything).Return(tc.issueResponse, tc.issueErr)
|
||||
updatedUser, err := svc.UpdateSecret(context.Background(), tc.session, tc.oldSecret, tc.newSecret)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.response, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.response.ID)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = repoCall1.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.response.Credentials.Username)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "UpdateSecret", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("UpdateSecret was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1049,20 +1065,22 @@ func TestUpdateEmail(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repocall2 := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
|
||||
repoCall1 := cRepo.On("UpdateEmail", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
|
||||
updatedUser, err := svc.UpdateEmail(context.Background(), authn.Session{DomainUserID: tc.reqUserID, UserID: validID, DomainID: validID}, tc.id, tc.email)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateEmailResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateEmailResponse, updatedUser))
|
||||
if tc.err == nil && user2.Email != tc.email {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "UpdateEmail", context.Background(), mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
user2.Email = tc.email
|
||||
}
|
||||
repoCall.Unset()
|
||||
repocall2.Unset()
|
||||
repoCall1.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repocall2 := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
|
||||
repoCall1 := cRepo.On("UpdateEmail", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
|
||||
updatedUser, err := svc.UpdateEmail(context.Background(), authn.Session{DomainUserID: tc.reqUserID, UserID: validID, DomainID: validID}, tc.id, tc.email)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateEmailResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateEmailResponse, updatedUser))
|
||||
if tc.err == nil && user2.Email != tc.email {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "UpdateEmail", context.Background(), mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
user2.Email = tc.email
|
||||
}
|
||||
repoCall.Unset()
|
||||
repocall2.Unset()
|
||||
repoCall1.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1152,19 +1170,21 @@ func TestUpdateProfilePicture(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr)
|
||||
repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateProfilePicResponse, tc.updateProfilePicErr)
|
||||
updatedUser, err := svc.UpdateProfilePicture(context.Background(), tc.session, tc.userID, tc.userReq)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateProfilePicResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateProfilePicResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr)
|
||||
repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateProfilePicResponse, tc.updateProfilePicErr)
|
||||
updatedUser, err := svc.UpdateProfilePicture(context.Background(), tc.session, tc.userID, tc.userReq)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateProfilePicResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateProfilePicResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1224,17 +1244,19 @@ func TestUpdateUsername(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("UpdateUsername", context.Background(), mock.Anything).Return(tc.updateUsernameResponse, tc.updateUsernameErr)
|
||||
updatedUser, err := svc.UpdateUsername(context.Background(), tc.session, tc.user.ID, tc.user.Credentials.Username)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateUsernameResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUsernameResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "UpdateUsername", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("UpdateUsername was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("UpdateUsername", context.Background(), mock.Anything).Return(tc.updateUsernameResponse, tc.updateUsernameErr)
|
||||
updatedUser, err := svc.UpdateUsername(context.Background(), tc.session, tc.user.ID, tc.user.Credentials.Username)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateUsernameResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUsernameResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "UpdateUsername", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("UpdateUsername was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1287,7 +1309,7 @@ func TestEnableUser(t *testing.T) {
|
||||
id: enabledUser1.ID,
|
||||
user: enabledUser1,
|
||||
retrieveByIDResponse: enabledUser1,
|
||||
err: errors.ErrStatusAlreadyAssigned,
|
||||
err: svcerr.ErrStatusAlreadyAssigned,
|
||||
},
|
||||
{
|
||||
desc: "enable disabled user with failed to change status",
|
||||
@@ -1301,21 +1323,23 @@ func TestEnableUser(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
|
||||
|
||||
_, err := svc.Enable(context.Background(), authn.Session{}, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
_, err := svc.Enable(context.Background(), authn.Session{}, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1368,7 +1392,7 @@ func TestDisableUser(t *testing.T) {
|
||||
id: disabledUser1.ID,
|
||||
user: disabledUser1,
|
||||
retrieveByIDResponse: disabledUser1,
|
||||
err: errors.ErrStatusAlreadyAssigned,
|
||||
err: svcerr.ErrStatusAlreadyAssigned,
|
||||
},
|
||||
{
|
||||
desc: "disable enabled user with failed to change status",
|
||||
@@ -1381,21 +1405,23 @@ func TestDisableUser(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
|
||||
|
||||
_, err := svc.Disable(context.Background(), authn.Session{}, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
_, err := svc.Disable(context.Background(), authn.Session{}, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1445,7 +1471,7 @@ func TestDeleteUser(t *testing.T) {
|
||||
user: deletedUser1,
|
||||
session: authn.Session{UserID: validID, SuperAdmin: true},
|
||||
retrieveByIDResponse: deletedUser1,
|
||||
err: errors.ErrStatusAlreadyAssigned,
|
||||
err: svcerr.ErrStatusAlreadyAssigned,
|
||||
},
|
||||
{
|
||||
desc: "delete enabled user with failed to change status",
|
||||
@@ -1460,20 +1486,22 @@ func TestDeleteUser(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall2 := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall3 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
repoCall4 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
|
||||
err := svc.Delete(context.Background(), tc.session, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall3.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = repoCall4.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall2.Unset()
|
||||
repoCall3.Unset()
|
||||
repoCall4.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall2 := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall3 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
|
||||
repoCall4 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
|
||||
err := svc.Delete(context.Background(), tc.session, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if tc.err == nil {
|
||||
ok := repoCall3.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = repoCall4.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall2.Unset()
|
||||
repoCall3.Unset()
|
||||
repoCall4.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1542,20 +1570,22 @@ func TestIssueToken(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr)
|
||||
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
|
||||
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
|
||||
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
|
||||
ok := repoCall.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.user.Credentials.Username)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
|
||||
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)})
|
||||
assert.True(t, ok, fmt.Sprintf("Issue was not called on %s", tc.desc))
|
||||
}
|
||||
authCall.Unset()
|
||||
repoCall.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr)
|
||||
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
|
||||
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
|
||||
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
|
||||
ok := repoCall.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.user.Credentials.Username)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
|
||||
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)})
|
||||
assert.True(t, ok, fmt.Sprintf("Issue was not called on %s", tc.desc))
|
||||
}
|
||||
authCall.Unset()
|
||||
repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1619,20 +1649,22 @@ func TestRefreshToken(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
authCall := authsvc.On("Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken}).Return(tc.refreshResp, tc.refresErr)
|
||||
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
|
||||
token, err := svc.RefreshToken(context.Background(), tc.session, validToken)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
|
||||
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
|
||||
ok := authCall.Parent.AssertCalled(t, "Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken})
|
||||
assert.True(t, ok, fmt.Sprintf("Refresh was not called on %s", tc.desc))
|
||||
ok = repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
}
|
||||
authCall.Unset()
|
||||
repoCall.Unset()
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
authCall := authsvc.On("Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken}).Return(tc.refreshResp, tc.refresErr)
|
||||
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
|
||||
token, err := svc.RefreshToken(context.Background(), tc.session, validToken)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
|
||||
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
|
||||
ok := authCall.Parent.AssertCalled(t, "Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken})
|
||||
assert.True(t, ok, fmt.Sprintf("Refresh was not called on %s", tc.desc))
|
||||
ok = repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
}
|
||||
authCall.Unset()
|
||||
repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+18
-13
@@ -54,7 +54,6 @@ var (
|
||||
errClientNotInitialized = errors.New("client is not initialized")
|
||||
errMissingTopicPub = errors.New("failed to publish due to missing topic")
|
||||
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
|
||||
errFailedPublish = errors.New("failed to publish")
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -164,7 +163,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with nil session",
|
||||
@@ -228,7 +227,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
authNRes1: smqauthn.Session{},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with unauthorized token",
|
||||
@@ -275,7 +274,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with b64 encoded credentials",
|
||||
@@ -306,7 +305,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with health check topic successfully",
|
||||
@@ -330,7 +329,7 @@ func TestAuthPublish(t *testing.T) {
|
||||
authKey: clientKey,
|
||||
payload: &payload,
|
||||
status: http.StatusBadRequest,
|
||||
err: errors.Wrap(errFailedPublish, messaging.ErrMalformedTopic),
|
||||
err: messaging.ErrMalformedTopic,
|
||||
clientType: policies.ClientType,
|
||||
},
|
||||
}
|
||||
@@ -359,7 +358,9 @@ func TestAuthPublish(t *testing.T) {
|
||||
if ok {
|
||||
assert.Equal(t, tc.status, hpe.StatusCode())
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err))
|
||||
if tc.err != nil {
|
||||
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
|
||||
}
|
||||
authCall.Unset()
|
||||
clientsCall.Unset()
|
||||
channelsCall.Unset()
|
||||
@@ -451,7 +452,7 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with empty topics",
|
||||
@@ -512,7 +513,7 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
authNRes1: smqauthn.Session{},
|
||||
authNErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with unauthorized token",
|
||||
@@ -556,7 +557,7 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "publish with b64 encoded credentials",
|
||||
@@ -585,7 +586,7 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
|
||||
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
|
||||
status: http.StatusUnauthorized,
|
||||
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "subscribe with health check topic successfully",
|
||||
@@ -636,7 +637,9 @@ func TestAuthSubscribe(t *testing.T) {
|
||||
if ok {
|
||||
assert.Equal(t, tc.status, hpe.StatusCode())
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err))
|
||||
if tc.err != nil {
|
||||
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
|
||||
}
|
||||
authCall.Unset()
|
||||
clientsCall.Unset()
|
||||
channelsCall.Unset()
|
||||
@@ -727,7 +730,9 @@ func TestPublish(t *testing.T) {
|
||||
}
|
||||
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
err := handler.Publish(ctx, &tc.topic, &tc.payload)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err))
|
||||
if tc.err != nil {
|
||||
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
|
||||
}
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user