SMQ-1744 - Error handling with TypedError created on top existing Error (#3170)

Signed-off-by: Arvindh <arvindh91@gmail.com>
Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
Co-authored-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Arvindh
2025-12-22 13:01:52 +05:30
committed by GitHub
parent 1355bc8bb7
commit 3fcf2e5369
71 changed files with 1919 additions and 1623 deletions
+37 -137
View File
@@ -16,7 +16,6 @@ import (
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/users"
"github.com/gofrs/uuid/v5"
)
@@ -184,146 +183,47 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
return
}
var wrapper error
if errors.Contains(err, apiutil.ErrValidation) {
wrapper, err = errors.Unwrap(err)
}
switch {
case errors.Contains(err, errors.ErrTryAgain):
w.WriteHeader(http.StatusUnprocessableEntity)
case errors.Contains(err, errors.ErrEmailAlreadyExists),
errors.Contains(err, errors.ErrUsernameNotAvailable),
errors.Contains(err, errors.ErrRouteNotAvailable),
errors.Contains(err, errors.ErrChannelRouteNotAvailable),
errors.Contains(err, errors.ErrDomainRouteNotAvailable),
errors.Contains(err, svcerr.ErrExternalAuthProviderCouldNotUpdate):
switch retErr := err.(type) {
case *errors.RequestError:
w.WriteHeader(http.StatusBadRequest)
case errors.Contains(err, svcerr.ErrAuthorization),
errors.Contains(err, svcerr.ErrDomainAuthorization),
errors.Contains(err, svcerr.ErrUnauthorizedPAT),
errors.Contains(err, svcerr.ErrSuperAdminAction):
err = unwrap(err)
w.WriteHeader(http.StatusForbidden)
case errors.Contains(err, svcerr.ErrAuthentication),
errors.Contains(err, apiutil.ErrBearerToken),
errors.Contains(err, svcerr.ErrLogin),
errors.Contains(err, apiutil.ErrUnsupportedTokenType):
err = unwrap(err)
if err := json.NewEncoder(w).Encode(retErr); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
return
case *errors.AuthNError:
w.WriteHeader(http.StatusUnauthorized)
case errors.Contains(err, svcerr.ErrMalformedEntity),
errors.Contains(err, apiutil.ErrMalformedPolicy),
errors.Contains(err, apiutil.ErrMissingSecret),
errors.Contains(err, errors.ErrMalformedEntity),
errors.Contains(err, apiutil.ErrMissingID),
errors.Contains(err, apiutil.ErrInvalidVerification),
errors.Contains(err, apiutil.ErrMissingName),
errors.Contains(err, apiutil.ErrMissingEmail),
errors.Contains(err, apiutil.ErrInvalidEmail),
errors.Contains(err, apiutil.ErrMissingHost),
errors.Contains(err, apiutil.ErrInvalidResetPass),
errors.Contains(err, apiutil.ErrEmptyList),
errors.Contains(err, apiutil.ErrMissingMemberKind),
errors.Contains(err, apiutil.ErrMissingMemberType),
errors.Contains(err, apiutil.ErrLimitSize),
errors.Contains(err, apiutil.ErrBearerKey),
errors.Contains(err, svcerr.ErrInvalidStatus),
errors.Contains(err, apiutil.ErrNameSize),
errors.Contains(err, apiutil.ErrInvalidIDFormat),
errors.Contains(err, apiutil.ErrInvalidQueryParams),
errors.Contains(err, apiutil.ErrMissingRelation),
errors.Contains(err, apiutil.ErrValidation),
errors.Contains(err, apiutil.ErrMissingPass),
errors.Contains(err, apiutil.ErrMissingConfPass),
errors.Contains(err, apiutil.ErrPasswordFormat),
errors.Contains(err, svcerr.ErrInvalidRole),
errors.Contains(err, svcerr.ErrInvalidPolicy),
errors.Contains(err, apiutil.ErrInvitationState),
errors.Contains(err, apiutil.ErrInvalidAPIKey),
errors.Contains(err, svcerr.ErrViewEntity),
errors.Contains(err, apiutil.ErrMissingCertData),
errors.Contains(err, apiutil.ErrInvalidContact),
errors.Contains(err, apiutil.ErrInvalidTopic),
errors.Contains(err, apiutil.ErrInvalidCertData),
errors.Contains(err, apiutil.ErrEmptyMessage),
errors.Contains(err, apiutil.ErrInvalidLevel),
errors.Contains(err, apiutil.ErrInvalidDirection),
errors.Contains(err, apiutil.ErrInvalidEntityType),
errors.Contains(err, apiutil.ErrMissingEntityType),
errors.Contains(err, apiutil.ErrInvalidTimeFormat),
errors.Contains(err, svcerr.ErrSearch),
errors.Contains(err, apiutil.ErrEmptySearchQuery),
errors.Contains(err, apiutil.ErrLenSearchQuery),
errors.Contains(err, apiutil.ErrMissingDomainID),
errors.Contains(err, apiutil.ErrMissingUserID),
errors.Contains(err, apiutil.ErrMissingPATID),
errors.Contains(err, apiutil.ErrMissingUsername),
errors.Contains(err, apiutil.ErrMissingUsernameEmail),
errors.Contains(err, apiutil.ErrMissingFirstName),
errors.Contains(err, apiutil.ErrMissingLastName),
errors.Contains(err, apiutil.ErrInvalidUsername),
errors.Contains(err, apiutil.ErrMissingIdentity),
errors.Contains(err, apiutil.ErrInvalidProfilePictureURL),
errors.Contains(err, apiutil.ErrSelfParentingNotAllowed),
errors.Contains(err, apiutil.ErrMissingChildrenGroupIDs),
errors.Contains(err, apiutil.ErrMissingParentGroupID),
errors.Contains(err, apiutil.ErrMissingConnectionType),
errors.Contains(err, apiutil.ErrMissingRoleName),
errors.Contains(err, apiutil.ErrMissingRoleID),
errors.Contains(err, apiutil.ErrMissingPolicyEntityType),
errors.Contains(err, apiutil.ErrMissingRoleMembers),
errors.Contains(err, apiutil.ErrMissingDescription),
errors.Contains(err, apiutil.ErrMissingEntityID),
errors.Contains(err, apiutil.ErrInvalidRouteFormat),
errors.Contains(err, svcerr.ErrRetainOneMember),
errors.Contains(err, apiutil.ErrMissingRoute):
err = unwrap(err)
w.WriteHeader(http.StatusBadRequest)
case errors.Contains(err, svcerr.ErrCreateEntity),
errors.Contains(err, svcerr.ErrUpdateEntity),
errors.Contains(err, svcerr.ErrRemoveEntity),
errors.Contains(err, svcerr.ErrEnableClient),
errors.Contains(err, svcerr.ErrEnableUser),
errors.Contains(err, svcerr.ErrDisableUser):
err = unwrap(err)
w.WriteHeader(http.StatusUnprocessableEntity)
case errors.Contains(err, svcerr.ErrNotFound):
err = unwrap(err)
w.WriteHeader(http.StatusNotFound)
case errors.Contains(err, errors.ErrStatusAlreadyAssigned),
errors.Contains(err, svcerr.ErrInvitationAlreadyRejected),
errors.Contains(err, svcerr.ErrInvitationAlreadyAccepted),
errors.Contains(err, svcerr.ErrConflict):
err = unwrap(err)
w.WriteHeader(http.StatusConflict)
case errors.Contains(err, apiutil.ErrUnsupportedContentType):
err = unwrap(err)
if err := json.NewEncoder(w).Encode(retErr); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
return
case *errors.AuthZError:
w.WriteHeader(http.StatusForbidden)
if err := json.NewEncoder(w).Encode(retErr); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
return
case *errors.MediaTypeError:
w.WriteHeader(http.StatusUnsupportedMediaType)
if err := json.NewEncoder(w).Encode(retErr); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
return
case *errors.ServiceError:
w.WriteHeader(http.StatusUnprocessableEntity)
if err := json.NewEncoder(w).Encode(retErr); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
return
case *errors.NotFoundError:
w.WriteHeader(http.StatusNotFound)
if err := json.NewEncoder(w).Encode(retErr); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
return
case *errors.InternalError:
w.WriteHeader(http.StatusInternalServerError)
return
default:
w.WriteHeader(http.StatusInternalServerError)
}
if wrapper != nil {
err = errors.Wrap(wrapper, err)
}
if errorVal, ok := err.(errors.Error); ok {
if err := json.NewEncoder(w).Encode(errorVal); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
}
}
func unwrap(err error) error {
wrapper, err := errors.Unwrap(err)
if wrapper != nil {
return wrapper
}
return err
}
+75 -92
View File
@@ -218,121 +218,104 @@ func TestEncodeResponse(t *testing.T) {
func TestEncodeError(t *testing.T) {
cases := []struct {
desc string
errs []error
code int
desc string
err error
code int
hasBody bool
checkError bool
}{
{
desc: "BadRequest",
errs: []error{
apiutil.ErrMissingSecret,
svcerr.ErrMalformedEntity,
errors.ErrMalformedEntity,
apiutil.ErrMissingID,
apiutil.ErrEmptyList,
apiutil.ErrMissingMemberType,
apiutil.ErrMissingMemberKind,
apiutil.ErrLimitSize,
apiutil.ErrNameSize,
svcerr.ErrViewEntity,
},
code: http.StatusBadRequest,
desc: "RequestError - Missing Secret",
err: apiutil.ErrMissingSecret,
code: http.StatusBadRequest,
hasBody: true,
},
{
desc: "BadRequest with validation error",
errs: []error{
errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingSecret),
errors.Wrap(apiutil.ErrValidation, svcerr.ErrMalformedEntity),
errors.Wrap(apiutil.ErrValidation, errors.ErrMalformedEntity),
errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID),
errors.Wrap(apiutil.ErrValidation, apiutil.ErrEmptyList),
errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingMemberType),
errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingMemberKind),
errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize),
errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize),
},
code: http.StatusBadRequest,
desc: "RequestError - Missing ID",
err: apiutil.ErrMissingID,
code: http.StatusBadRequest,
hasBody: true,
},
{
desc: "Unauthorized",
errs: []error{
svcerr.ErrAuthentication,
svcerr.ErrAuthentication,
apiutil.ErrBearerToken,
},
code: http.StatusUnauthorized,
},
{
desc: "NotFound",
errs: []error{
svcerr.ErrNotFound,
},
code: http.StatusNotFound,
desc: "RequestError - Empty List",
err: apiutil.ErrEmptyList,
code: http.StatusBadRequest,
hasBody: true,
},
{
desc: "Conflict",
errs: []error{
svcerr.ErrConflict,
svcerr.ErrConflict,
},
code: http.StatusConflict,
desc: "RequestError - Conflict",
err: svcerr.ErrConflict,
code: http.StatusBadRequest,
hasBody: true,
},
{
desc: "Forbidden",
errs: []error{
svcerr.ErrAuthorization,
svcerr.ErrAuthorization,
svcerr.ErrDomainAuthorization,
},
code: http.StatusForbidden,
desc: "NotFoundError - Not Found",
err: svcerr.ErrNotFound,
code: http.StatusNotFound,
hasBody: true,
},
{
desc: "UnsupportedMediaType",
errs: []error{
apiutil.ErrUnsupportedContentType,
},
code: http.StatusUnsupportedMediaType,
desc: "AuthNError - Authentication Failed",
err: svcerr.ErrAuthentication,
code: http.StatusUnauthorized,
hasBody: true,
},
{
desc: "StatusUnprocessableEntity",
errs: []error{
svcerr.ErrCreateEntity,
svcerr.ErrUpdateEntity,
svcerr.ErrRemoveEntity,
},
code: http.StatusUnprocessableEntity,
desc: "AuthZError - Authorization Failed",
err: svcerr.ErrAuthorization,
code: http.StatusForbidden,
hasBody: true,
},
{
desc: "InternalServerError",
errs: []error{
errors.New("test"),
},
code: http.StatusInternalServerError,
desc: "AuthZError - Domain Authorization Failed",
err: svcerr.ErrDomainAuthorization,
code: http.StatusForbidden,
hasBody: true,
},
{
desc: "MediaTypeError - Unsupported Content Type",
err: apiutil.ErrUnsupportedContentType,
code: http.StatusUnsupportedMediaType,
hasBody: true,
},
{
desc: "ServiceError - Create Entity Failed",
err: svcerr.ErrCreateEntity,
code: http.StatusUnprocessableEntity,
hasBody: true,
},
{
desc: "ServiceError - Update Entity Failed",
err: svcerr.ErrUpdateEntity,
code: http.StatusUnprocessableEntity,
hasBody: true,
},
{
desc: "ServiceError - Remove Entity Failed",
err: svcerr.ErrRemoveEntity,
code: http.StatusUnprocessableEntity,
hasBody: true,
},
{
desc: "InternalError",
err: errors.NewInternalError(),
code: http.StatusInternalServerError,
hasBody: false,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
responseWriter := newResponseWriter()
for _, err := range c.errs {
api.EncodeError(context.Background(), err, responseWriter)
assert.Equal(t, c.code, responseWriter.StatusCode())
message := body{}
jerr := json.Unmarshal(responseWriter.Body(), &message)
assert.NoError(t, jerr)
var wrapper error
switch errors.Contains(err, apiutil.ErrValidation) {
case true:
wrapper, err = errors.Unwrap(err)
assert.Equal(t, err.Error(), message.Error)
assert.Equal(t, wrapper.Error(), message.Message)
case false:
assert.Equal(t, err.Error(), message.Message)
}
api.EncodeError(context.Background(), c.err, responseWriter)
assert.Equal(t, c.code, responseWriter.StatusCode())
if !c.hasBody {
return
}
message := body{}
jerr := json.Unmarshal(responseWriter.Body(), &message)
assert.NoError(t, jerr)
assert.NotEmpty(t, message.Message)
})
}
}
+91 -88
View File
@@ -10,268 +10,271 @@ import "github.com/absmach/supermq/pkg/errors"
// errors are logged twice.
var (
// ErrValidation indicates that an error was returned by the API.
ErrValidation = errors.New("something went wrong with the request")
ErrValidation = errors.NewRequestError("something went wrong with the request")
// ErrBearerToken indicates missing or invalid bearer user token.
ErrBearerToken = errors.New("missing or invalid bearer user token")
ErrBearerToken = errors.NewAuthNError("missing or invalid bearer user token")
// ErrBearerKey indicates missing or invalid bearer entity key.
ErrBearerKey = errors.New("missing or invalid bearer entity key")
ErrBearerKey = errors.NewAuthNError("missing or invalid bearer entity key")
// ErrMissingID indicates missing entity ID.
ErrMissingID = errors.New("missing entity id")
ErrMissingID = errors.NewRequestError("missing entity id")
// ErrMissingEntityID indicates missing entity ID.
ErrMissingEntityID = errors.New("missing entity id")
ErrMissingEntityID = errors.NewRequestError("missing entity id")
// ErrMissingClientID indicates missing client ID.
ErrMissingClientID = errors.New("missing cient id")
ErrMissingClientID = errors.NewRequestError("missing client id")
// ErrMissingChannelID indicates missing client ID.
ErrMissingChannelID = errors.New("missing channel id")
ErrMissingChannelID = errors.NewRequestError("missing channel id")
// ErrMissingConnectionType indicates missing connection tpye.
ErrMissingConnectionType = errors.New("missing connection type")
ErrMissingConnectionType = errors.NewRequestError("missing connection type")
// ErrMissingParentGroupID indicates missing parent group ID.
ErrMissingParentGroupID = errors.New("missing parent group id")
ErrMissingParentGroupID = errors.NewRequestError("missing parent group id")
// ErrMissingChildrenGroupIDs indicates missing children group IDs.
ErrMissingChildrenGroupIDs = errors.New("missing children group ids")
ErrMissingChildrenGroupIDs = errors.NewRequestError("missing children group ids")
// ErrSelfParentingNotAllowed indicates child id is same as parent id.
ErrSelfParentingNotAllowed = errors.New("self parenting not allowed")
ErrSelfParentingNotAllowed = errors.NewRequestError("self parenting not allowed")
// ErrInvalidChildGroupID indicates invalid child group ID.
ErrInvalidChildGroupID = errors.New("invalid child group id")
ErrInvalidChildGroupID = errors.NewRequestError("invalid child group id")
// ErrInvalidAuthKey indicates invalid auth key.
ErrInvalidAuthKey = errors.New("invalid auth key")
// ErrInvalidIDFormat indicates an invalid ID format.
ErrInvalidIDFormat = errors.New("invalid id format provided")
ErrInvalidIDFormat = errors.NewRequestError("invalid id format provided")
// ErrNameSize indicates that name size exceeds the max.
ErrNameSize = errors.New("invalid name size")
ErrNameSize = errors.NewRequestError("invalid name size")
// ErrEmailSize indicates that email size exceeds the max.
ErrEmailSize = errors.New("invalid email size")
ErrEmailSize = errors.NewRequestError("invalid email size")
// ErrInvalidRole indicates that an invalid role.
ErrInvalidRole = errors.New("invalid client role")
ErrInvalidRole = errors.NewRequestError("invalid client role")
// ErrLimitSize indicates that an invalid limit.
ErrLimitSize = errors.New("invalid limit size")
ErrLimitSize = errors.NewRequestError("invalid limit size")
// ErrLevel indicates that an invalid level.
ErrLevel = errors.New("invalid level")
ErrLevel = errors.NewRequestError("invalid level")
// ErrOffsetSize indicates an invalid offset.
ErrOffsetSize = errors.New("invalid offset size")
ErrOffsetSize = errors.NewRequestError("invalid offset size")
// ErrInvalidOrder indicates an invalid list order.
ErrInvalidOrder = errors.New("invalid list order provided")
ErrInvalidOrder = errors.NewRequestError("invalid list order provided")
// ErrInvalidDirection indicates an invalid list direction.
ErrInvalidDirection = errors.New("invalid list direction provided")
ErrInvalidDirection = errors.NewRequestError("invalid list direction provided")
// ErrInvalidMemberKind indicates an invalid member kind.
ErrInvalidMemberKind = errors.New("invalid member kind")
ErrInvalidMemberKind = errors.NewRequestError("invalid member kind")
// ErrEmptyList indicates that entity data is empty.
ErrEmptyList = errors.New("empty list provided")
ErrEmptyList = errors.NewRequestError("empty list provided")
// ErrMissingRoleName indicates that role name is empty.
ErrMissingRoleName = errors.New("empty role name")
ErrMissingRoleName = errors.NewRequestError("empty role name")
// ErrMissingRoleID indicates that role id is empty.
ErrMissingRoleID = errors.New("empty role id")
ErrMissingRoleID = errors.NewRequestError("empty role id")
// ErrMissingRoleOperations indicates that role operations are empty.
ErrMissingRoleOperations = errors.New("empty role operations")
ErrMissingRoleOperations = errors.NewRequestError("empty role operations")
// ErrMissingRoleMembers indicates that role members are empty.
ErrMissingRoleMembers = errors.New("empty role members")
ErrMissingRoleMembers = errors.NewRequestError("empty role members")
// ErrMalformedPolicy indicates that policies are malformed.
ErrMalformedPolicy = errors.New("malformed policy")
ErrMalformedPolicy = errors.NewRequestError("malformed policy")
// ErrMissingPolicySub indicates that policies are subject.
ErrMissingPolicySub = errors.New("malformed policy subject")
ErrMissingPolicySub = errors.NewRequestError("malformed policy subject")
// ErrMissingPolicyObj indicates missing policies object.
ErrMissingPolicyObj = errors.New("malformed policy object")
ErrMissingPolicyObj = errors.NewRequestError("malformed policy object")
// ErrMalformedPolicyAct indicates missing policies action.
ErrMalformedPolicyAct = errors.New("malformed policy action")
ErrMalformedPolicyAct = errors.NewRequestError("malformed policy action")
// ErrMissingPolicyEntityType indicates missing policies entity type.
ErrMissingPolicyEntityType = errors.New("missing policy entity type")
ErrMissingPolicyEntityType = errors.NewRequestError("missing policy entity type")
// ErrMalformedPolicyPer indicates missing policies relation.
ErrMalformedPolicyPer = errors.New("malformed policy permission")
ErrMalformedPolicyPer = errors.NewRequestError("malformed policy permission")
// ErrMissingCertData indicates missing cert data (ttl).
ErrMissingCertData = errors.New("missing certificate data")
ErrMissingCertData = errors.NewRequestError("missing certificate data")
// ErrInvalidCertData indicates invalid cert data (ttl).
ErrInvalidCertData = errors.New("invalid certificate data")
ErrInvalidCertData = errors.NewRequestError("invalid certificate data")
// ErrInvalidTopic indicates an invalid subscription topic.
ErrInvalidTopic = errors.New("invalid Subscription topic")
ErrInvalidTopic = errors.NewRequestError("invalid Subscription topic")
// ErrInvalidContact indicates an invalid subscription contract.
ErrInvalidContact = errors.New("invalid Subscription contact")
ErrInvalidContact = errors.NewRequestError("invalid Subscription contact")
// ErrMissingEmail indicates missing email.
ErrMissingEmail = errors.New("missing email")
ErrMissingEmail = errors.NewRequestError("missing email")
// ErrInvalidEmail indicates missing email.
ErrInvalidEmail = errors.New("invalid email")
ErrInvalidEmail = errors.NewRequestError("invalid email")
// ErrMissingHost indicates missing host.
ErrMissingHost = errors.New("missing host")
ErrMissingHost = errors.NewRequestError("missing host")
// ErrMissingPass indicates missing password.
ErrMissingPass = errors.New("missing password")
ErrMissingPass = errors.NewRequestError("missing password")
// ErrMissingConfPass indicates missing conf password.
ErrMissingConfPass = errors.New("missing conf password")
ErrMissingConfPass = errors.NewRequestError("missing conf password")
// ErrInvalidResetPass indicates an invalid reset password.
ErrInvalidResetPass = errors.New("invalid reset password")
ErrInvalidResetPass = errors.NewRequestError("invalid reset password")
// ErrInvalidComparator indicates an invalid comparator.
ErrInvalidComparator = errors.New("invalid comparator")
ErrInvalidComparator = errors.NewRequestError("invalid comparator")
// ErrMissingMemberIDs indicates missing member ids.
ErrMissingMemberIDs = errors.New("missing member ids")
ErrMissingMemberIDs = errors.NewRequestError("missing member ids")
// ErrMissingMemberType indicates missing group member type.
ErrMissingMemberType = errors.New("missing group member type")
ErrMissingMemberType = errors.NewRequestError("missing group member type")
// ErrMissingMemberKind indicates missing group member kind.
ErrMissingMemberKind = errors.New("missing group member kind")
ErrMissingMemberKind = errors.NewRequestError("missing group member kind")
// ErrMissingRelation indicates missing relation.
ErrMissingRelation = errors.New("missing relation")
ErrMissingRelation = errors.NewRequestError("missing relation")
// ErrInvalidRelation indicates an invalid relation.
ErrInvalidRelation = errors.New("invalid relation")
ErrInvalidRelation = errors.NewRequestError("invalid relation")
// ErrInvalidAPIKey indicates an invalid API key type.
ErrInvalidAPIKey = errors.New("invalid api key type")
ErrInvalidAPIKey = errors.NewRequestError("invalid api key type")
// ErrInvitationState indicates an invalid invitation state.
ErrInvitationState = errors.New("invalid invitation state")
ErrInvitationState = errors.NewRequestError("invalid invitation state")
// ErrMissingIdentity indicates missing entity Identity.
ErrMissingIdentity = errors.New("missing entity identity")
ErrMissingIdentity = errors.NewRequestError("missing entity identity")
// ErrMissingSecret indicates missing secret.
ErrMissingSecret = errors.New("missing secret")
ErrMissingSecret = errors.NewRequestError("missing secret")
// ErrPasswordFormat indicates weak password.
ErrPasswordFormat = errors.New("password does not meet the requirements")
ErrPasswordFormat = errors.NewRequestError("password does not meet the requirements")
// ErrMissingName indicates missing identity name.
ErrMissingName = errors.New("missing identity name")
ErrMissingName = errors.NewRequestError("missing identity name")
// ErrMissingRoute indicates missing route.
ErrMissingRoute = errors.New("missing route")
ErrMissingRoute = errors.NewRequestError("missing route")
// ErrInvalidLevel indicates an invalid group level.
ErrInvalidLevel = errors.New("invalid group level (should be between 0 and 5)")
ErrInvalidLevel = errors.NewRequestError("invalid group level (should be between 0 and 5)")
// ErrNotFoundParam indicates that the parameter was not found in the query.
ErrNotFoundParam = errors.New("parameter not found in the query")
ErrNotFoundParam = errors.NewRequestError("parameter not found in the query")
// ErrInvalidQueryParams indicates invalid query parameters.
ErrInvalidQueryParams = errors.New("invalid query parameters")
ErrInvalidQueryParams = errors.NewRequestError("invalid query parameters")
// ErrInvalidVisibilityType indicates invalid visibility type.
ErrInvalidVisibilityType = errors.New("invalid visibility type")
ErrInvalidVisibilityType = errors.NewRequestError("invalid visibility type")
// ErrUnsupportedContentType indicates unacceptable or lack of Content-Type.
ErrUnsupportedContentType = errors.New("unsupported content type")
ErrUnsupportedContentType = errors.NewMediaTypeError("unsupported content type")
// ErrRollbackTx indicates failed to rollback transaction.
ErrRollbackTx = errors.New("failed to rollback transaction")
ErrRollbackTx = errors.NewRequestError("failed to rollback transaction")
// ErrInvalidAggregation indicates invalid aggregation value.
ErrInvalidAggregation = errors.New("invalid aggregation value")
ErrInvalidAggregation = errors.NewRequestError("invalid aggregation value")
// ErrInvalidInterval indicates invalid interval value.
ErrInvalidInterval = errors.New("invalid interval value")
ErrInvalidInterval = errors.NewRequestError("invalid interval value")
// ErrMissingFrom indicates missing from value.
ErrMissingFrom = errors.New("missing from time value")
ErrMissingFrom = errors.NewRequestError("missing from time value")
// ErrMissingTo indicates missing to value.
ErrMissingTo = errors.New("missing to time value")
ErrMissingTo = errors.NewRequestError("missing to time value")
// ErrEmptyMessage indicates empty message.
ErrEmptyMessage = errors.New("empty message")
ErrEmptyMessage = errors.NewRequestError("empty message")
// ErrMissingEntityType indicates missing entity type.
ErrMissingEntityType = errors.New("missing entity type")
ErrMissingEntityType = errors.NewRequestError("missing entity type")
// ErrInvalidEntityType indicates invalid entity type.
ErrInvalidEntityType = errors.New("invalid entity type")
ErrInvalidEntityType = errors.NewRequestError("invalid entity type")
// ErrInvalidTimeFormat indicates invalid time format i.e not unix time.
ErrInvalidTimeFormat = errors.New("invalid time format use unix time")
ErrInvalidTimeFormat = errors.NewRequestError("invalid time format use unix time")
// ErrEmptySearchQuery indicates search query should not be empty.
ErrEmptySearchQuery = errors.New("search query must not be empty")
ErrEmptySearchQuery = errors.NewRequestError("search query must not be empty")
// ErrLenSearchQuery indicates search query length.
ErrLenSearchQuery = errors.New("search query must be at least 3 characters")
ErrLenSearchQuery = errors.NewRequestError("search query must be at least 3 characters")
// ErrMissingDomainID indicates missing domainID.
ErrMissingDomainID = errors.New("missing domainID")
ErrMissingDomainID = errors.NewRequestError("missing domainID")
// ErrMissingUsername indicates missing user name.
ErrMissingUsername = errors.New("missing username")
ErrMissingUsername = errors.NewRequestError("missing username")
// ErrInvalidUsername indicates invalid user name.
ErrInvalidUsername = errors.New("invalid username")
ErrInvalidUsername = errors.NewRequestError("invalid username")
// ErrMissingFirstName indicates missing first name.
ErrMissingFirstName = errors.New("missing first name")
ErrMissingFirstName = errors.NewRequestError("missing first name")
// ErrMissingLastName indicates missing last name.
ErrMissingLastName = errors.New("missing last name")
ErrMissingLastName = errors.NewRequestError("missing last name")
// ErrInvalidProfilePictureURL indicates that the profile picture url is invalid.
ErrInvalidProfilePictureURL = errors.New("invalid profile picture url")
ErrInvalidProfilePictureURL = errors.NewRequestError("invalid profile picture url")
ErrMultipleEntitiesFilter = errors.New("multiple entities are provided in filter are not supported")
ErrMultipleEntitiesFilter = errors.NewRequestError("multiple entities are provided in filter are not supported")
// ErrMissingDescription indicates missing description.
ErrMissingDescription = errors.New("missing description")
ErrMissingDescription = errors.NewRequestError("missing description")
// ErrUnsupportedTokenType indicates that this type of token is not supported.
ErrUnsupportedTokenType = errors.New("unsupported content token type")
ErrUnsupportedTokenType = errors.NewRequestError("unsupported content token type")
// ErrMissingUserID indicates missing user ID.
ErrMissingUserID = errors.New("missing user id")
ErrMissingUserID = errors.NewRequestError("missing user id")
// ErrMissingPATID indicates missing pat ID.
ErrMissingPATID = errors.New("missing pat id")
ErrMissingPATID = errors.NewRequestError("missing pat id")
// ErrInvalidNameFormat indicates invalid name format.
ErrInvalidNameFormat = errors.New("invalid name format")
ErrInvalidNameFormat = errors.NewRequestError("invalid name format")
// ErrInvalidRouteFormat indicates invalid route format.
ErrInvalidRouteFormat = errors.New("invalid route format")
ErrInvalidRouteFormat = errors.NewRequestError("invalid route format")
// ErrMissingUsernameEmail indicates missing user name / email.
ErrMissingUsernameEmail = errors.New("missing username / email")
ErrMissingUsernameEmail = errors.NewRequestError("missing username / email")
// ErrInvalidVerification indicates invalid email verification.
ErrInvalidVerification = errors.New("invalid verification")
ErrInvalidVerification = errors.NewRequestError("invalid verification")
// ErrEmailNotVerified indicates invalid email not verified.
ErrEmailNotVerified = errors.New("email not verified")
ErrEmailNotVerified = errors.NewRequestError("email not verified")
// ErrMalformedRequest indicates malformed request body.
ErrMalformedRequestBody = errors.NewRequestError("request body is not a valid JSON, expecting a valid JSON")
)
+1 -1
View File
@@ -253,7 +253,7 @@ func TestRetrieve(t *testing.T) {
desc: "retrieve a non-existing key",
id: "non-existing",
token: token.AccessToken,
status: http.StatusBadRequest,
status: http.StatusNotFound,
err: svcerr.ErrNotFound,
},
{
+1 -1
View File
@@ -57,7 +57,7 @@ func decodeIssue(_ context.Context, r *http.Request) (any, error) {
req := issueKeyReq{token: apiutil.ExtractBearerToken(r)}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
+6 -6
View File
@@ -140,7 +140,7 @@ func decodeCreatePATRequest(_ context.Context, r *http.Request) (any, error) {
}
req := createPatReq{token: token}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -171,7 +171,7 @@ func decodeUpdatePATNameRequest(_ context.Context, r *http.Request) (any, error)
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -190,7 +190,7 @@ func decodeUpdatePATDescriptionRequest(_ context.Context, r *http.Request) (any,
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -262,7 +262,7 @@ func decodeResetPATSecretRequest(_ context.Context, r *http.Request) (any, error
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -305,7 +305,7 @@ func decodeAddScopeRequest(_ context.Context, r *http.Request) (any, error) {
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -348,7 +348,7 @@ func decodeRemoveScopeRequest(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
+3 -6
View File
@@ -26,10 +26,7 @@ const (
secret = "test"
)
var (
errInvalidIssuer = errors.New("invalid token issuer value")
reposecret = []byte("test")
)
var reposecret = []byte("test")
func newToken(issuerName string, key auth.Key) string {
builder := jwt.NewBuilder()
@@ -194,7 +191,7 @@ func TestParse(t *testing.T) {
desc: "parse token with invalid issuer",
key: auth.Key{},
token: inValidToken,
err: errInvalidIssuer,
err: svcerr.ErrAuthentication,
},
{
desc: "parse token with invalid content",
@@ -212,7 +209,7 @@ func TestParse(t *testing.T) {
desc: "parse token with empty type",
key: emptyTypeKey,
token: emptyTypeToken,
err: errors.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
}
+5 -5
View File
@@ -40,11 +40,11 @@ var (
errMalformedPAT = errors.New("malformed personal access token")
errFailedToParseUUID = errors.New("failed to parse string to UUID")
errInvalidLenFor2UUIDs = errors.New("invalid input length for 2 UUID, excepted 32 byte")
errRevokedPAT = errors.New("revoked pat")
errCreatePAT = errors.New("failed to create PAT")
errUpdatePAT = errors.New("failed to update PAT")
errRetrievePAT = errors.New("failed to retrieve PAT")
errDeletePAT = errors.New("failed to delete PAT")
errRevokedPAT = errors.NewServiceError("revoked pat")
errCreatePAT = errors.NewServiceError("failed to create PAT")
errUpdatePAT = errors.NewServiceError("failed to update PAT")
errRetrievePAT = errors.NewServiceError("failed to retrieve PAT")
errDeletePAT = errors.NewServiceError("failed to delete PAT")
errInvalidScope = errors.New("invalid scope")
)
+9 -9
View File
@@ -38,7 +38,7 @@ func decodeCreateChannelReq(_ context.Context, r *http.Request) (any, error) {
req := createChannelReq{}
if err := json.NewDecoder(r.Body).Decode(&req.Channel); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -51,7 +51,7 @@ func decodeCreateChannelsReq(_ context.Context, r *http.Request) (any, error) {
req := createChannelsReq{}
if err := json.NewDecoder(r.Body).Decode(&req.Channels); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -188,7 +188,7 @@ func decodeUpdateChannel(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -203,7 +203,7 @@ func decodeUpdateChannelTags(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -218,7 +218,7 @@ func decodeSetChannelParentGroupStatus(_ context.Context, r *http.Request) (any,
id: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -254,7 +254,7 @@ func decodeConnectChannelClientRequest(_ context.Context, r *http.Request) (any,
channelID: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -268,7 +268,7 @@ func decodeDisconnectChannelClientsRequest(_ context.Context, r *http.Request) (
channelID: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -281,7 +281,7 @@ func decodeConnectRequest(_ context.Context, r *http.Request) (any, error) {
req := connectRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -294,7 +294,7 @@ func decodeDisconnectRequest(_ context.Context, r *http.Request) (any, error) {
req := disconnectRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
+6 -6
View File
@@ -574,7 +574,7 @@ func TestListChannels(t *testing.T) {
token: validToken,
query: "offset=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list channels with limit",
@@ -596,7 +596,7 @@ func TestListChannels(t *testing.T) {
token: validToken,
query: "limit=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list channels with limit greater than max",
@@ -604,7 +604,7 @@ func TestListChannels(t *testing.T) {
domainID: validID,
query: fmt.Sprintf("limit=%d", api.MaxLimitSize+1),
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrLimitSize,
},
{
desc: "list channels with name",
@@ -656,7 +656,7 @@ func TestListChannels(t *testing.T) {
token: validToken,
query: "status=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: svcerr.ErrInvalidStatus,
},
{
desc: "list channels with duplicate status",
@@ -716,7 +716,7 @@ func TestListChannels(t *testing.T) {
token: validToken,
query: "metadata=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list channels with duplicate metadata",
@@ -1114,7 +1114,7 @@ func TestUpdateChannelTagsEndpoint(t *testing.T) {
contentType: contentType,
data: fmt.Sprintf(`{"tags":["%s"}`, newTag),
status: http.StatusBadRequest,
err: errors.ErrMalformedEntity,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "update channel with empty id",
+38 -46
View File
@@ -22,7 +22,6 @@ import (
"github.com/absmach/supermq/pkg/roles"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v5/pgconn"
"github.com/lib/pq"
)
@@ -30,14 +29,13 @@ const (
rolesTableNamePrefix = "channels"
entityTableName = "channels"
entityIDColumnName = "id"
pgDuplicateErrCode = "23505"
)
var _ channels.Repository = (*channelRepository)(nil)
type channelRepository struct {
db postgres.Database
eh errors.Handler
rolesPostgres.Repository
}
@@ -45,8 +43,12 @@ type channelRepository struct {
// repository.
func NewRepository(db postgres.Database) channels.Repository {
rolesRepo := rolesPostgres.NewRepository(db, policies.ChannelType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
errHandlerOptions := []errors.HandlerOption{
postgres.WithDuplicateErrors(NewDuplicateErrors()),
}
return &channelRepository{
db: db,
eh: postgres.NewErrorHandler(errHandlerOptions...),
Repository: rolesRepo,
}
}
@@ -67,7 +69,7 @@ func (cr *channelRepository) Save(ctx context.Context, chs ...channels.Channel)
row, err := cr.db.NamedQueryContext(ctx, q, dbchs)
if err != nil {
return []channels.Channel{}, handleSaveError(repoerr.ErrCreateEntity, err)
return []channels.Channel{}, cr.eh.HandleError(repoerr.ErrCreateEntity, err)
}
defer row.Close()
@@ -77,7 +79,7 @@ func (cr *channelRepository) Save(ctx context.Context, chs ...channels.Channel)
for row.Next() {
dbch := dbChannel{}
if err := row.StructScan(&dbch); err != nil {
return []channels.Channel{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
return []channels.Channel{}, cr.eh.HandleError(repoerr.ErrFailedOpDB, err)
}
ch, err := toChannel(dbch)
@@ -89,16 +91,6 @@ func (cr *channelRepository) Save(ctx context.Context, chs ...channels.Channel)
return reChs, nil
}
func handleSaveError(wrapper, err error) error {
if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgDuplicateErrCode {
switch pqErr.ConstraintName {
case "unique_domain_route_not_null":
return errors.ErrRouteNotAvailable
}
}
return postgres.HandleError(wrapper, err)
}
func (cr *channelRepository) Update(ctx context.Context, channel channels.Channel) (channels.Channel, error) {
var query []string
var upq string
@@ -144,14 +136,14 @@ func (cr *channelRepository) RetrieveByID(ctx context.Context, id string) (chann
row, err := cr.db.NamedQueryContext(ctx, q, dbch)
if err != nil {
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
dbch = dbChannel{}
if row.Next() {
if err := row.StructScan(&dbch); err != nil {
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
return toChannel(dbch)
}
@@ -170,14 +162,14 @@ func (cr *channelRepository) RetrieveByRoute(ctx context.Context, route, domainI
row, err := cr.db.NamedQueryContext(ctx, q, dbch)
if err != nil {
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
dbch = dbChannel{}
if row.Next() {
if err := row.StructScan(&dbch); err != nil {
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
return toChannel(dbch)
}
@@ -368,17 +360,17 @@ func (cr *channelRepository) RetrieveByIDWithRoles(ctx context.Context, id, memb
}
row, err := cr.db.NamedQueryContext(ctx, query, parameters)
if err != nil {
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
dbch := dbChannel{}
if !row.Next() {
return channels.Channel{}, errors.Wrap(repoerr.ErrNotFound, err)
return channels.Channel{}, repoerr.ErrNotFound
}
if err := row.StructScan(&dbch); err != nil {
return channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
return toChannel(dbch)
@@ -450,14 +442,14 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.Page)
if !pm.OnlyTotal {
rows, err := cr.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return channels.ChannelsPage{}, cr.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
for rows.Next() {
dbch := dbChannel{}
if err := rows.StructScan(&dbch); err != nil {
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
ch, err := toChannel(dbch)
@@ -476,7 +468,7 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.Page)
total, err := postgres.Total(ctx, cr.db, cq, dbPage)
if err != nil {
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
page := channels.ChannelsPage{
@@ -566,14 +558,14 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
if !pm.OnlyTotal {
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
for rows.Next() {
dbc := dbChannel{}
if err := rows.StructScan(&dbc); err != nil {
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := toChannel(dbc)
@@ -617,7 +609,7 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return channels.ChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return channels.ChannelsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
page := channels.ChannelsPage{
@@ -912,7 +904,7 @@ func (cr *channelRepository) Remove(ctx context.Context, ids ...string) error {
}
result, err := cr.db.NamedExecContext(ctx, q, params)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
@@ -928,7 +920,7 @@ func (cr *channelRepository) SetParentGroup(ctx context.Context, ch channels.Cha
}
result, err := cr.db.NamedExecContext(ctx, q, dbCh)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
return cr.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
@@ -944,7 +936,7 @@ func (cr *channelRepository) RemoveParentGroup(ctx context.Context, ch channels.
}
result, err := cr.db.NamedExecContext(ctx, q, dbCh)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
@@ -958,7 +950,7 @@ func (cr *channelRepository) AddConnections(ctx context.Context, conns []channel
VALUES (:channel_id, :domain_id, :client_id, :type );`
if _, err := cr.db.NamedExecContext(ctx, q, dbConns); err != nil {
return postgres.HandleError(repoerr.ErrCreateEntity, err)
return cr.eh.HandleError(repoerr.ErrCreateEntity, err)
}
return nil
@@ -967,7 +959,7 @@ func (cr *channelRepository) AddConnections(ctx context.Context, conns []channel
func (cr *channelRepository) RemoveConnections(ctx context.Context, conns []channels.Connection) (retErr error) {
tx, err := cr.db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
defer func() {
if retErr != nil {
@@ -985,11 +977,11 @@ func (cr *channelRepository) RemoveConnections(ctx context.Context, conns []chan
}
dbConn := toDBConnection(conn)
if _, err := tx.NamedExec(query, dbConn); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, errors.Wrap(fmt.Errorf("failed to delete connection for channel_id: %s, domain_id: %s client_id %s", conn.ChannelID, conn.DomainID, conn.ClientID), err))
return cr.eh.HandleError(repoerr.ErrRemoveEntity, errors.Wrap(fmt.Errorf("failed to delete connection for channel_id: %s, domain_id: %s client_id %s", conn.ChannelID, conn.DomainID, conn.ClientID), err))
}
}
if err := tx.Commit(); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
return nil
}
@@ -999,7 +991,7 @@ func (cr *channelRepository) CheckConnection(ctx context.Context, conn channels.
dbConn := toDBConnection(conn)
rows, err := cr.db.NamedQueryContext(ctx, query, dbConn)
if err != nil {
return postgres.HandleError(repoerr.ErrViewEntity, err)
return cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
@@ -1014,7 +1006,7 @@ func (cr *channelRepository) ClientAuthorize(ctx context.Context, conn channels.
dbConn := toDBConnection(conn)
rows, err := cr.db.NamedQueryContext(ctx, query, dbConn)
if err != nil {
return postgres.HandleError(repoerr.ErrViewEntity, err)
return cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
@@ -1030,7 +1022,7 @@ func (cr *channelRepository) ChannelConnectionsCount(ctx context.Context, id str
total, err := postgres.Total(ctx, cr.db, query, dbConn)
if err != nil {
return 0, postgres.HandleError(repoerr.ErrViewEntity, err)
return 0, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
return total, nil
}
@@ -1041,7 +1033,7 @@ func (cr *channelRepository) DoesChannelHaveConnections(ctx context.Context, id
rows, err := cr.db.NamedQueryContext(ctx, query, dbConn)
if err != nil {
return false, postgres.HandleError(repoerr.ErrViewEntity, err)
return false, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
@@ -1053,7 +1045,7 @@ func (cr *channelRepository) RemoveClientConnections(ctx context.Context, client
dbConn := dbConnection{ClientID: clientID}
if _, err := cr.db.NamedExecContext(ctx, query, dbConn); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
return nil
}
@@ -1063,7 +1055,7 @@ func (cr *channelRepository) RemoveChannelConnections(ctx context.Context, chann
dbConn := dbConnection{ChannelID: channelID}
if _, err := cr.db.NamedExecContext(ctx, query, dbConn); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
return nil
}
@@ -1074,7 +1066,7 @@ func (cr *channelRepository) RetrieveParentGroupChannels(ctx context.Context, pa
rows, err := cr.db.NamedQueryContext(ctx, query, dbChannel{ParentGroup: toNullString(parentGroupID)})
if err != nil {
return []channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
return []channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
@@ -1082,7 +1074,7 @@ func (cr *channelRepository) RetrieveParentGroupChannels(ctx context.Context, pa
for rows.Next() {
dbch := dbChannel{}
if err := rows.StructScan(&dbch); err != nil {
return []channels.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err)
return []channels.Channel{}, cr.eh.HandleError(repoerr.ErrViewEntity, err)
}
ch, err := toChannel(dbch)
@@ -1099,7 +1091,7 @@ func (cr *channelRepository) UnsetParentGroupFromChannels(ctx context.Context, p
query := "UPDATE channels SET parent_group_id = NULL WHERE parent_group_id = :parent_group_id"
if _, err := cr.db.NamedExecContext(ctx, query, dbChannel{ParentGroup: toNullString(parentGroupID)}); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
return cr.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
return nil
}
@@ -1112,14 +1104,14 @@ func (cr *channelRepository) update(ctx context.Context, ch channels.Channel, qu
row, err := cr.db.NamedQueryContext(ctx, query, dbch)
if err != nil {
return channels.Channel{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
dbch = dbChannel{}
if row.Next() {
if err := row.StructScan(&dbch); err != nil {
return channels.Channel{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
return channels.Channel{}, cr.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
return toChannel(dbch)
+10 -9
View File
@@ -102,6 +102,7 @@ var (
"subgroup_set_parent",
"subgroup_update",
}
errChannelExists = errors.New("channel id already exists")
)
func TestSave(t *testing.T) {
@@ -143,7 +144,7 @@ func TestSave(t *testing.T) {
desc: "add duplicate channel",
channel: validChannel,
resp: []channels.Channel{},
err: repoerr.ErrConflict,
err: errChannelExists,
},
{
desc: "add channel with invalid ID",
@@ -156,7 +157,7 @@ func TestSave(t *testing.T) {
Status: channels.EnabledStatus,
},
resp: []channels.Channel{},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add channel with invalid domain",
@@ -169,7 +170,7 @@ func TestSave(t *testing.T) {
Status: channels.EnabledStatus,
},
resp: []channels.Channel{},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add channel with invalid name",
@@ -182,7 +183,7 @@ func TestSave(t *testing.T) {
Status: channels.EnabledStatus,
},
resp: []channels.Channel{},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add channel with invalid metadata",
@@ -197,7 +198,7 @@ func TestSave(t *testing.T) {
Status: channels.EnabledStatus,
},
resp: []channels.Channel{},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add channel with duplicate name",
@@ -1086,7 +1087,7 @@ func TestSetParentGroup(t *testing.T) {
desc: "set parent group with invalid parent group ID",
id: validChannel.ID,
parentGroupID: invalidID,
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrUpdateEntity,
},
}
@@ -1208,7 +1209,7 @@ func TestAddConnection(t *testing.T) {
DomainID: testsutil.GenerateUUID(t),
Type: connections.Publish,
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add connection with invalid channel ID",
@@ -1218,7 +1219,7 @@ func TestAddConnection(t *testing.T) {
DomainID: testsutil.GenerateUUID(t),
Type: connections.Publish,
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add connection with invalid domain ID",
@@ -1228,7 +1229,7 @@ func TestAddConnection(t *testing.T) {
DomainID: invalidID,
Type: connections.Publish,
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
}
+26
View File
@@ -0,0 +1,26 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import "github.com/absmach/supermq/pkg/errors"
var _ errors.Mapper = (*duplicateErrors)(nil)
type duplicateErrors struct{}
// GetError maps constraint names to known errors.
func (d duplicateErrors) GetError(constraint string) (error, bool) {
switch constraint {
case "unique_domain_route_not_null":
return errors.ErrRouteNotAvailable, true
case "channels_pkey":
return errors.NewRequestError("channel id already exists"), true
default:
return nil, false
}
}
func NewDuplicateErrors() errors.Mapper {
return duplicateErrors{}
}
+1 -1
View File
@@ -501,7 +501,7 @@ func (svc service) changeChannelStatus(ctx context.Context, userID string, chann
return Channel{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if dbchannel.Status == channel.Status {
return Channel{}, errors.ErrStatusAlreadyAssigned
return Channel{}, svcerr.ErrStatusAlreadyAssigned
}
channel.UpdatedBy = userID
+6 -7
View File
@@ -63,10 +63,9 @@ var (
},
},
}
parentGroupID = testsutil.GenerateUUID(&testing.T{})
validID = testsutil.GenerateUUID(&testing.T{})
validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID}
errRollbackRoles = errors.New("failed to rollback roles")
parentGroupID = testsutil.GenerateUUID(&testing.T{})
validID = testsutil.GenerateUUID(&testing.T{})
validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID}
)
var (
@@ -203,7 +202,7 @@ func TestCreateChannel(t *testing.T) {
},
addRoleErr: svcerr.ErrCreateEntity,
deletePoliciesErr: svcerr.ErrRemoveEntity,
err: errRollbackRoles,
err: svcerr.ErrRemoveEntity,
},
}
@@ -404,7 +403,7 @@ func TestEnableChannel(t *testing.T) {
retrieveResp: channels.Channel{
Status: channels.EnabledStatus,
},
err: errors.ErrStatusAlreadyAssigned,
err: svcerr.ErrStatusAlreadyAssigned,
},
{
desc: "enable channel with retrieve error",
@@ -467,7 +466,7 @@ func TestDisableChannel(t *testing.T) {
retrieveResp: channels.Channel{
Status: channels.DisabledStatus,
},
err: errors.ErrStatusAlreadyAssigned,
err: svcerr.ErrStatusAlreadyAssigned,
},
{
desc: "disable channel with retrieve error",
+6 -6
View File
@@ -170,7 +170,7 @@ func decodeUpdateClient(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, clientID),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -185,7 +185,7 @@ func decodeUpdateClientTags(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, clientID),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -200,7 +200,7 @@ func decodeUpdateClientCredentials(_ context.Context, r *http.Request) (any, err
id: chi.URLParam(r, clientID),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -213,7 +213,7 @@ func decodeCreateClientReq(_ context.Context, r *http.Request) (any, error) {
var c clients.Client
if err := json.NewDecoder(r.Body).Decode(&c); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
req := createClientReq{
client: c,
@@ -229,7 +229,7 @@ func decodeCreateClientsReq(_ context.Context, r *http.Request) (any, error) {
c := createClientsReq{}
if err := json.NewDecoder(r.Body).Decode(&c.Clients); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return c, nil
@@ -252,7 +252,7 @@ func decodeSetClientParentGroupStatus(_ context.Context, r *http.Request) (any,
id: chi.URLParam(r, clientID),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
+30 -45
View File
@@ -134,7 +134,7 @@ func TestCreateClient(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: contentType,
status: http.StatusConflict,
status: http.StatusBadRequest,
err: svcerr.ErrConflict,
},
{
@@ -161,7 +161,7 @@ func TestCreateClient(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidIDFormat,
},
{
desc: "register a client that can't be marshalled",
@@ -179,7 +179,7 @@ func TestCreateClient(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: contentType,
status: http.StatusBadRequest,
err: errors.ErrMalformedEntity,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "register client with invalid status",
@@ -212,7 +212,7 @@ func TestCreateClient(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: "application/xml",
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
}
@@ -316,7 +316,7 @@ func TestCreateClients(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrEmptyList,
len: 0,
},
{
@@ -337,7 +337,7 @@ func TestCreateClients(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidIDFormat,
},
{
desc: "create clients with invalid contentype",
@@ -351,7 +351,7 @@ func TestCreateClients(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: "application/xml",
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "create a client that can't be marshalled",
@@ -372,7 +372,7 @@ func TestCreateClients(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusBadRequest,
err: errors.ErrMalformedEntity,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "create clients with service error",
@@ -499,7 +499,7 @@ func TestListClients(t *testing.T) {
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
query: "offset=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list clients with limit",
@@ -524,7 +524,7 @@ func TestListClients(t *testing.T) {
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
query: "limit=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list clients with limit greater than max",
@@ -533,7 +533,7 @@ func TestListClients(t *testing.T) {
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
query: fmt.Sprintf("limit=%d", api.MaxLimitSize+1),
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrLimitSize,
},
{
desc: "list clients with name",
@@ -590,7 +590,7 @@ func TestListClients(t *testing.T) {
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
query: "status=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: svcerr.ErrInvalidStatus,
},
{
desc: "list clients with duplicate status",
@@ -656,7 +656,7 @@ func TestListClients(t *testing.T) {
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID, SuperAdmin: false},
query: "metadata=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list clients with duplicate metadata",
@@ -913,8 +913,7 @@ func TestUpdateClient(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: "application/xml",
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update client with malformed data",
@@ -925,8 +924,7 @@ func TestUpdateClient(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "update client with empty id",
@@ -937,8 +935,7 @@ func TestUpdateClient(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrMissingID,
err: apiutil.ErrMissingID,
},
{
desc: "update client with name that is too long",
@@ -1020,8 +1017,7 @@ func TestUpdateClientsTags(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusOK,
err: nil,
err: nil,
},
{
desc: "update client tags with empty token",
@@ -1053,8 +1049,7 @@ func TestUpdateClientsTags(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusForbidden,
err: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "update client tags with invalid contentype",
@@ -1065,7 +1060,7 @@ func TestUpdateClientsTags(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update clients tags with empty id",
@@ -1076,8 +1071,7 @@ func TestUpdateClientsTags(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMissingID,
},
{
desc: "update clients with malfomed data",
@@ -1088,8 +1082,7 @@ func TestUpdateClientsTags(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusBadRequest,
err: errors.ErrMalformedEntity,
err: apiutil.ErrMalformedRequestBody,
},
}
@@ -1203,7 +1196,7 @@ func TestUpdateClientSecret(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMissingID,
},
{
desc: "update client secret with empty secret",
@@ -1220,8 +1213,7 @@ func TestUpdateClientSecret(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMissingSecret,
},
{
desc: "update client secret with invalid contentype",
@@ -1238,8 +1230,7 @@ func TestUpdateClientSecret(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update client secret with malformed data",
@@ -1256,8 +1247,7 @@ func TestUpdateClientSecret(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
}
@@ -1316,8 +1306,7 @@ func TestEnableClient(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusOK,
err: nil,
err: nil,
},
{
desc: "enable client with invalid token",
@@ -1337,8 +1326,7 @@ func TestEnableClient(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMissingID,
},
}
@@ -1401,8 +1389,7 @@ func TestDisableClient(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusOK,
err: nil,
err: nil,
},
{
desc: "disable client with invalid token",
@@ -1422,8 +1409,7 @@ func TestDisableClient(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMissingID,
},
}
@@ -1481,8 +1467,7 @@ func TestDeleteClient(t *testing.T) {
token: validToken,
authnRes: smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID},
status: http.StatusNoContent,
err: nil,
err: nil,
},
{
desc: "delete client with invalid token",
+26 -22
View File
@@ -36,6 +36,7 @@ var _ clients.Repository = (*clientRepo)(nil)
type clientRepo struct {
DB postgres.Database
eh errors.Handler
rolesPostgres.Repository
}
@@ -43,9 +44,12 @@ type clientRepo struct {
// implementation of Clients repository.
func NewRepository(db postgres.Database) clients.Repository {
repo := rolesPostgres.NewRepository(db, policies.ClientType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
errHandlerOptions := []errors.HandlerOption{
postgres.WithDuplicateErrors(NewDuplicateErrors()),
}
return &clientRepo{
DB: db,
eh: postgres.NewErrorHandler(errHandlerOptions...),
Repository: repo,
}
}
@@ -66,7 +70,7 @@ func (repo *clientRepo) Save(ctx context.Context, cls ...clients.Client) ([]clie
row, err := repo.DB.NamedQueryContext(ctx, q, dbClients)
if err != nil {
return []clients.Client{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
return []clients.Client{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
}
defer row.Close()
@@ -75,7 +79,7 @@ func (repo *clientRepo) Save(ctx context.Context, cls ...clients.Client) ([]clie
for row.Next() {
dbcli := DBClient{}
if err := row.StructScan(&dbcli); err != nil {
return []clients.Client{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
return []clients.Client{}, repo.eh.HandleError(repoerr.ErrFailedOpDB, err)
}
client, err := ToClient(dbcli)
@@ -108,14 +112,14 @@ func (repo *clientRepo) RetrieveBySecret(ctx context.Context, key, id string, pr
rows, err := repo.DB.NamedQueryContext(ctx, q, dbc)
if err != nil {
return clients.Client{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
dbc = DBClient{}
if rows.Next() {
if err = rows.StructScan(&dbc); err != nil {
return clients.Client{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
client, err := ToClient(dbc)
@@ -365,17 +369,17 @@ func (repo *clientRepo) RetrieveByIDWithRoles(ctx context.Context, id, memberID
}
row, err := repo.DB.NamedQueryContext(ctx, query, parameters)
if err != nil {
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
dbc := DBClient{}
if !row.Next() {
return clients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
return clients.Client{}, repoerr.ErrNotFound
}
if err := row.StructScan(&dbc); err != nil {
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return ToClient(dbc)
@@ -391,14 +395,14 @@ func (repo *clientRepo) RetrieveByID(ctx context.Context, id string) (clients.Cl
row, err := repo.DB.NamedQueryContext(ctx, q, dbc)
if err != nil {
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
dbc = DBClient{}
if row.Next() {
if err := row.StructScan(&dbc); err != nil {
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.Client{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return ToClient(dbc)
@@ -471,14 +475,14 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien
if !pm.OnlyTotal {
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := ToClient(dbc)
@@ -497,7 +501,7 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
page := clients.ClientsPage{
@@ -588,14 +592,14 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
if !pm.OnlyTotal {
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := ToClient(dbc)
@@ -639,7 +643,7 @@ func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID st
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
page := clients.ClientsPage{
@@ -945,7 +949,7 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli
rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
@@ -953,7 +957,7 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := ToClient(dbc)
@@ -967,7 +971,7 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli
cq := fmt.Sprintf(`SELECT COUNT(*) FROM clients c %s;`, tq)
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return clients.ClientsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
page := clients.ClientsPage{
@@ -990,14 +994,14 @@ func (repo *clientRepo) update(ctx context.Context, client clients.Client, query
row, err := repo.DB.NamedQueryContext(ctx, query, dbc)
if err != nil {
return clients.Client{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
return clients.Client{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
dbc = DBClient{}
if row.Next() {
if err := row.StructScan(&dbc); err != nil {
return clients.Client{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
return clients.Client{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
return ToClient(dbc)
@@ -1014,7 +1018,7 @@ func (repo *clientRepo) Delete(ctx context.Context, clientIDs ...string) error {
}
result, err := repo.DB.NamedExecContext(ctx, q, params)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
return repo.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
+17 -14
View File
@@ -101,6 +101,7 @@ var (
"subgroup_set_parent",
"subgroup_update",
}
errClientSecretNotAvailable = errors.New("client key is not available")
)
func TestClientsSave(t *testing.T) {
@@ -189,7 +190,7 @@ func TestClientsSave(t *testing.T) {
Status: clients.EnabledStatus,
},
},
err: repoerr.ErrConflict,
err: errClientSecretNotAvailable,
},
{
desc: "add multiple clients with one client having duplicate secret",
@@ -216,7 +217,7 @@ func TestClientsSave(t *testing.T) {
Status: clients.EnabledStatus,
},
},
err: repoerr.ErrConflict,
err: errClientSecretNotAvailable,
},
{
desc: "add new client without domain id",
@@ -249,7 +250,7 @@ func TestClientsSave(t *testing.T) {
Status: clients.EnabledStatus,
},
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add multiple clients with one client having invalid client id",
@@ -275,7 +276,7 @@ func TestClientsSave(t *testing.T) {
Status: clients.EnabledStatus,
},
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add client with invalid client name",
@@ -292,7 +293,7 @@ func TestClientsSave(t *testing.T) {
Status: clients.EnabledStatus,
},
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add client with invalid client domain id",
@@ -308,7 +309,7 @@ func TestClientsSave(t *testing.T) {
Status: clients.EnabledStatus,
},
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add client with invalid client identity",
@@ -324,7 +325,7 @@ func TestClientsSave(t *testing.T) {
Status: clients.EnabledStatus,
},
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add client with a missing client identity",
@@ -390,14 +391,16 @@ func TestClientsSave(t *testing.T) {
},
}
for _, tc := range cases {
rClients, err := repo.Save(context.Background(), tc.clients...)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
for i := range rClients {
tc.clients[i].Credentials.Secret = rClients[i].Credentials.Secret
t.Run(tc.desc, func(t *testing.T) {
rClients, err := repo.Save(context.Background(), tc.clients...)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
for i := range rClients {
tc.clients[i].Credentials.Secret = rClients[i].Credentials.Secret
}
assert.Equal(t, tc.clients, rClients, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.clients, rClients))
}
assert.Equal(t, tc.clients, rClients, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.clients, rClients))
}
})
}
}
+24
View File
@@ -0,0 +1,24 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import "github.com/absmach/supermq/pkg/errors"
var _ errors.Mapper = (*duplicateErrors)(nil)
type duplicateErrors struct{}
// GetError maps constraint names to known errors.
func (d duplicateErrors) GetError(constraint string) (error, bool) {
switch constraint {
case "clients_domain_id_secret_key":
return errors.NewRequestError("client key is not available"), true
default:
return nil, false
}
}
func NewDuplicateErrors() errors.Mapper {
return duplicateErrors{}
}
+11 -10
View File
@@ -4,7 +4,6 @@ package clients
import (
"context"
"fmt"
"time"
smq "github.com/absmach/supermq"
@@ -20,9 +19,11 @@ import (
)
var (
errRollbackRepo = errors.New("failed to rollback repo")
errSetParentGroup = errors.New("client already have parent")
errSetSameParentGroup = errors.New("client already assigned to the parent group")
errRollbackRepo = errors.New("failed to rollback repo")
errSetParentGroup = errors.NewRequestError("client already have parent")
errSetSameParentGroup = errors.NewRequestError("client already assigned to the parent group")
errParentGroupDomainID = errors.NewRequestError("parent group has invalid domain id")
errParentGroupDisabled = errors.NewRequestError("parent group is not enabled")
)
var _ Service = (*service)(nil)
@@ -59,14 +60,14 @@ func (svc service) CreateClients(ctx context.Context, session authn.Session, cls
if c.ID == "" {
clientID, err := svc.idProvider.ID()
if err != nil {
return []Client{}, []roles.RoleProvision{}, err
return []Client{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrIssueProviderID, err)
}
c.ID = clientID
}
if c.Credentials.Secret == "" {
key, err := svc.idProvider.ID()
if err != nil {
return []Client{}, []roles.RoleProvision{}, err
return []Client{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrIssueProviderID, err)
}
c.Credentials.Secret = key
}
@@ -260,10 +261,10 @@ func (svc service) SetParentGroup(ctx context.Context, session authn.Session, pa
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
if resp.GetEntity().GetDomainId() != session.DomainID {
return errors.Wrap(svcerr.ErrUpdateEntity, fmt.Errorf("parent group id %s has invalid domain id", parentGroupID))
return errors.Wrap(svcerr.ErrUpdateEntity, errParentGroupDomainID)
}
if resp.GetEntity().GetStatus() != uint32(EnabledStatus) {
return errors.Wrap(svcerr.ErrUpdateEntity, fmt.Errorf("parent group id %s is not in enabled state", parentGroupID))
return errors.Wrap(svcerr.ErrUpdateEntity, errParentGroupDisabled)
}
var pols []policies.Policy
@@ -326,7 +327,7 @@ func (svc service) RemoveParentGroup(ctx context.Context, session authn.Session,
cli := Client{ID: id, UpdatedBy: session.UserID, UpdatedAt: time.Now().UTC()}
if err := svc.repo.RemoveParentGroup(ctx, cli); err != nil {
return err
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
}
return nil
@@ -388,7 +389,7 @@ func (svc service) changeClientStatus(ctx context.Context, session authn.Session
return Client{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if dbClient.Status == client.Status {
return Client{}, errors.ErrStatusAlreadyAssigned
return Client{}, svcerr.ErrStatusAlreadyAssigned
}
client.UpdatedBy = session.UserID
+4 -4
View File
@@ -771,8 +771,8 @@ func TestEnable(t *testing.T) {
client: enabledClient1,
changeStatusResponse: enabledClient1,
retrieveByIDResponse: enabledClient1,
changeStatusErr: errors.ErrStatusAlreadyAssigned,
err: errors.ErrStatusAlreadyAssigned,
changeStatusErr: svcerr.ErrStatusAlreadyAssigned,
err: svcerr.ErrStatusAlreadyAssigned,
},
{
desc: "enable non-existing client",
@@ -844,8 +844,8 @@ func TestDisable(t *testing.T) {
client: disabledClient1,
changeStatusResponse: clients.Client{},
retrieveByIDResponse: disabledClient1,
changeStatusErr: errors.ErrStatusAlreadyAssigned,
err: errors.ErrStatusAlreadyAssigned,
changeStatusErr: svcerr.ErrStatusAlreadyAssigned,
err: svcerr.ErrStatusAlreadyAssigned,
},
{
desc: "disable non-existing client",
+5 -6
View File
@@ -21,7 +21,6 @@ const (
domainIDKey = "domain_id"
invitedByKey = "invited_by"
roleIDKey = "role_id"
roleNameKey = "role_name"
stateKey = "state"
)
@@ -31,7 +30,7 @@ func decodeCreateDomainRequest(_ context.Context, r *http.Request) (any, error)
}
req := createDomainReq{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -59,7 +58,7 @@ func decodeUpdateDomainRequest(_ context.Context, r *http.Request) (any, error)
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -191,7 +190,7 @@ func decodeSendInvitationReq(_ context.Context, r *http.Request) (any, error) {
var req sendInvitationReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -257,7 +256,7 @@ func decodeAcceptInvitationReq(_ context.Context, r *http.Request) (any, error)
var req acceptInvitationReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -269,7 +268,7 @@ func decodeDeleteInvitationReq(_ context.Context, r *http.Request) (any, error)
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
+7 -7
View File
@@ -211,7 +211,7 @@ func TestCreateDomain(t *testing.T) {
token: validToken,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "register domain with service error",
@@ -563,14 +563,14 @@ func TestListDomains(t *testing.T) {
status: http.StatusOK,
},
{
desc: "list domains with invalid dir",
desc: "list domains with invalid dir",
token: validToken,
query: "dir= ",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
},
{
desc: "list domains with duplicate dir",
desc: "list domains with duplicate dir",
token: validToken,
query: "dir=asc&dir=asc",
status: http.StatusBadRequest,
@@ -585,7 +585,7 @@ func TestListDomains(t *testing.T) {
Order: api.DefOrder,
Dir: api.DefDir,
},
status: http.StatusBadRequest,
status: http.StatusUnprocessableEntity,
listDomainsResp: domains.DomainsPage{},
svcErr: svcerr.ErrViewEntity,
err: svcerr.ErrViewEntity,
@@ -652,10 +652,10 @@ func TestViewDomain(t *testing.T) {
err: svcerr.ErrAuthentication,
},
{
desc: "view domain with invalid id",
desc: "view domain with service error",
token: validToken,
domainID: invalid,
status: http.StatusBadRequest,
status: http.StatusUnprocessableEntity,
svcErr: svcerr.ErrViewEntity,
err: svcerr.ErrViewEntity,
},
@@ -822,7 +822,7 @@ func TestUpdateDomain(t *testing.T) {
},
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "update domain with invalid id",
+22 -30
View File
@@ -20,7 +20,6 @@ import (
"github.com/absmach/supermq/pkg/roles"
rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)
@@ -31,12 +30,11 @@ const (
rolesTableNamePrefix = "domains"
entityTableName = "domains"
entityIDColumnName = "id"
pgDuplicateErrCode = "23505"
)
type domainRepo struct {
db postgres.Database
eh errors.Handler
rolesPostgres.Repository
}
@@ -44,8 +42,12 @@ type domainRepo struct {
// implementation of Domain repository.
func NewRepository(db postgres.Database) domains.Repository {
rmsvcRepo := rolesPostgres.NewRepository(db, policies.DomainType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
errHandlerOptions := []errors.HandlerOption{
postgres.WithDuplicateErrors(NewDuplicateErrors()),
}
return &domainRepo{
db: db,
eh: postgres.NewErrorHandler(errHandlerOptions...),
Repository: rmsvcRepo,
}
}
@@ -62,7 +64,7 @@ func (repo domainRepo) SaveDomain(ctx context.Context, d domains.Domain) (dd dom
row, err := repo.db.NamedQueryContext(ctx, q, dbd)
if err != nil {
return domains.Domain{}, handleSaveError(repoerr.ErrCreateEntity, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
}
defer row.Close()
@@ -72,7 +74,7 @@ func (repo domainRepo) SaveDomain(ctx context.Context, d domains.Domain) (dd dom
dbd = dbDomain{}
if err := row.StructScan(&dbd); err != nil {
return domains.Domain{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrFailedOpDB, err)
}
domain, err := toDomain(dbd)
@@ -83,16 +85,6 @@ func (repo domainRepo) SaveDomain(ctx context.Context, d domains.Domain) (dd dom
return domain, nil
}
func handleSaveError(wrapper, err error) error {
if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgDuplicateErrCode {
switch pqErr.ConstraintName {
case "domains_route_key":
return errors.ErrRouteNotAvailable
}
}
return postgres.HandleError(wrapper, err)
}
// RetrieveDomainByIDWithRoles retrieves Domain by its unique ID along with member roles.
func (repo domainRepo) RetrieveDomainByIDWithRoles(ctx context.Context, id string, memberID string) (domains.Domain, error) {
q := `
@@ -176,14 +168,14 @@ func (repo domainRepo) RetrieveDomainByIDWithRoles(ctx context.Context, id strin
rows, err := repo.db.NamedQueryContext(ctx, q, dbdp)
if err != nil {
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
dbd := dbDomain{}
if rows.Next() {
if err = rows.StructScan(&dbd); err != nil {
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
domain, err := toDomain(dbd)
@@ -207,14 +199,14 @@ func (repo domainRepo) RetrieveDomainByID(ctx context.Context, id string) (domai
rows, err := repo.db.NamedQueryContext(ctx, q, dbdp)
if err != nil {
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
dbd := dbDomain{}
if rows.Next() {
if err = rows.StructScan(&dbd); err != nil {
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
domain, err := toDomain(dbd)
@@ -238,14 +230,14 @@ func (repo domainRepo) RetrieveDomainByRoute(ctx context.Context, route string)
rows, err := repo.db.NamedQueryContext(ctx, q, dbdom)
if err != nil {
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
dbd := dbDomain{}
if rows.Next() {
if err = rows.StructScan(&dbd); err != nil {
return domains.Domain{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
domain, err := toDomain(dbd)
@@ -284,13 +276,13 @@ func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.P
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
doms, err := repo.processRows(rows)
if err != nil {
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
cq := "SELECT COUNT(*) FROM domains d"
@@ -300,7 +292,7 @@ func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.P
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
return domains.DomainsPage{
@@ -370,13 +362,13 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
if !pm.OnlyTotal {
rows, err := repo.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
doms, err = repo.processRows(rows)
if err != nil {
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
}
@@ -393,7 +385,7 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
total, err := postgres.Total(ctx, repo.db, cq, dbPage)
if err != nil {
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return domains.DomainsPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
return domains.DomainsPage{
@@ -451,7 +443,7 @@ func (repo domainRepo) UpdateDomain(ctx context.Context, id string, dr domains.D
row, err := repo.db.NamedQueryContext(ctx, q, dbd)
if err != nil {
return domains.Domain{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
@@ -461,7 +453,7 @@ func (repo domainRepo) UpdateDomain(ctx context.Context, id string, dr domains.D
dbd = dbDomain{}
if err := row.StructScan(&dbd); err != nil {
return domains.Domain{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
return domains.Domain{}, repo.eh.HandleError(repoerr.ErrFailedOpDB, err)
}
domain, err := toDomain(dbd)
@@ -478,7 +470,7 @@ func (repo domainRepo) DeleteDomain(ctx context.Context, id string) error {
res, err := repo.db.ExecContext(ctx, q, id)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
return repo.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
if rows, _ := res.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
+4 -3
View File
@@ -21,8 +21,9 @@ import (
const invalid = "invalid"
var (
domainID = testsutil.GenerateUUID(&testing.T{})
userID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
userID = testsutil.GenerateUUID(&testing.T{})
errDomainExists = errors.New("domain already exists")
)
func TestSaveDomain(t *testing.T) {
@@ -72,7 +73,7 @@ func TestSaveDomain(t *testing.T) {
UpdatedBy: userID,
Status: domains.EnabledStatus,
},
err: repoerr.ErrConflict,
err: errDomainExists,
},
{
desc: "add domain with empty ID",
+26
View File
@@ -0,0 +1,26 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import "github.com/absmach/supermq/pkg/errors"
var _ errors.Mapper = (*duplicateErrors)(nil)
type duplicateErrors struct{}
// GetError maps constraint names to known errors.
func (d duplicateErrors) GetError(constraint string) (error, bool) {
switch constraint {
case "domains_route_key":
return errors.ErrRouteNotAvailable, true
case "domains_pkey":
return errors.NewRequestError("domain already exists"), true
default:
return nil, false
}
}
func NewDuplicateErrors() errors.Mapper {
return duplicateErrors{}
}
+10 -20
View File
@@ -26,26 +26,16 @@ import (
)
const (
secret = "secret"
email = "test@example.com"
id = "testID"
groupName = "smqx"
description = "Description"
memberRelation = "member"
authoritiesObj = "authorities"
loginDuration = 30 * time.Minute
refreshDuration = 24 * time.Hour
invalidDuration = 7 * 24 * time.Hour
validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22"
groupName = "smqx"
validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22"
)
var (
ErrExpiry = errors.New("session is expired")
errAddPolicies = errors.New("failed to add policies")
errRollbackRepo = errors.New("failed to rollback repo")
inValid = "invalid"
valid = "valid"
domain = domains.Domain{
ErrExpiry = errors.New("session is expired")
errAddPolicies = errors.New("failed to add policies")
inValid = "invalid"
valid = "valid"
domain = domains.Domain{
ID: validID,
Name: groupName,
Tags: []string{"tag1", "tag2"},
@@ -173,7 +163,7 @@ func TestCreateDomain(t *testing.T) {
session: validSession,
addPoliciesErr: errAddPolicies,
deleteDomainErr: svcerr.ErrRemoveEntity,
err: errRollbackRepo,
err: svcerr.ErrRemoveEntity,
},
{
desc: "create domain with failed to add roles",
@@ -194,7 +184,7 @@ func TestCreateDomain(t *testing.T) {
session: validSession,
addRolesErr: errors.ErrMalformedEntity,
deleteDomainErr: errors.ErrMalformedEntity,
err: errRollbackRepo,
err: errors.ErrMalformedEntity,
},
}
@@ -206,7 +196,7 @@ func TestCreateDomain(t *testing.T) {
policyCall := policy.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPoliciesErr)
policyCall1 := policy.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
_, _, err := svc.CreateDomain(context.Background(), tc.session, tc.d)
assert.True(t, errors.Contains(err, tc.err))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
+6 -6
View File
@@ -22,7 +22,7 @@ func DecodeGroupCreate(_ context.Context, r *http.Request) (any, error) {
}
var g groups.Group
if err := json.NewDecoder(r.Body).Decode(&g); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
req := createGroupReq{
Group: g,
@@ -63,7 +63,7 @@ func DecodeGroupUpdate(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "groupID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -77,7 +77,7 @@ func decodeUpdateGroupTags(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "groupID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -126,7 +126,7 @@ func decodeAddParentGroupRequest(_ context.Context, r *http.Request) (any, error
id: chi.URLParam(r, "groupID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -146,7 +146,7 @@ func decodeAddChildrenGroupsRequest(_ context.Context, r *http.Request) (any, er
id: chi.URLParam(r, "groupID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -159,7 +159,7 @@ func decodeRemoveChildrenGroupsRequest(_ context.Context, r *http.Request) (any,
id: chi.URLParam(r, "groupID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
+2 -2
View File
@@ -347,7 +347,7 @@ func TestDecodeGroupCreate(t *testing.T) {
"Content-Type": {api.ContentType},
},
resp: nil,
err: errors.ErrMalformedEntity,
err: apiutil.ErrMalformedRequestBody,
},
}
@@ -400,7 +400,7 @@ func TestDecodeGroupUpdate(t *testing.T) {
"Content-Type": {api.ContentType},
},
resp: nil,
err: errors.ErrMalformedEntity,
err: apiutil.ErrMalformedRequestBody,
},
}
+12 -12
View File
@@ -152,7 +152,7 @@ func TestCreateGroupEndpoint(t *testing.T) {
},
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrNameSize,
},
{
desc: "create group with name that is too long",
@@ -572,7 +572,7 @@ func TestUpdateGroupTagsEndpoint(t *testing.T) {
contentType: contentType,
data: fmt.Sprintf(`{"tags":["%s"}`, newTag),
status: http.StatusBadRequest,
err: errors.ErrMalformedEntity,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "update group with empty id",
@@ -886,7 +886,7 @@ func TestListGroups(t *testing.T) {
token: validToken,
query: "offset=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list groups with limit",
@@ -908,7 +908,7 @@ func TestListGroups(t *testing.T) {
token: validToken,
query: "limit=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list groups with limit greater than max",
@@ -916,7 +916,7 @@ func TestListGroups(t *testing.T) {
domainID: validID,
query: fmt.Sprintf("limit=%d", api.MaxLimitSize+1),
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrLimitSize,
},
{
desc: "list groups with name",
@@ -968,7 +968,7 @@ func TestListGroups(t *testing.T) {
token: validToken,
query: "status=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: svcerr.ErrInvalidStatus,
},
{
desc: "list groups with duplicate status",
@@ -1028,7 +1028,7 @@ func TestListGroups(t *testing.T) {
token: validToken,
query: "metadata=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list groups with duplicate metadata",
@@ -1330,7 +1330,7 @@ func TestRetrieveGroupHierarchyEndpoint(t *testing.T) {
domainID: validID,
query: "level=invalid&dir=-1&tree=false",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "retrieve group hierarchy with invalid direction",
@@ -1339,7 +1339,7 @@ func TestRetrieveGroupHierarchyEndpoint(t *testing.T) {
domainID: validID,
query: "level=1&dir=invalid&tree=false",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "retrieve group hierarchy with invalid tree",
@@ -1348,7 +1348,7 @@ func TestRetrieveGroupHierarchyEndpoint(t *testing.T) {
domainID: validID,
query: "level=1&dir=-1&tree=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "retrieve group hierarchy with empty groupID",
@@ -2116,7 +2116,7 @@ func TestListChildrenGroupsEndpoint(t *testing.T) {
domainID: validID,
query: "limit=invalid&offset=0",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list children groups with invalid offset",
@@ -2125,7 +2125,7 @@ func TestListChildrenGroupsEndpoint(t *testing.T) {
domainID: validID,
query: "limit=1&offset=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list children groups with empty id",
+26
View File
@@ -0,0 +1,26 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import "github.com/absmach/supermq/pkg/errors"
var _ errors.Mapper = (*duplicateErrors)(nil)
var errCyclicParentGroup = errors.NewRequestError("cyclic parent, group is parent of requested group")
type duplicateErrors struct{}
// GetError maps constraint names to known errors.
func (d duplicateErrors) GetError(constraint string) (error, bool) {
switch constraint {
case "groups_pkey":
return errors.NewRequestError("group id already exists"), true
default:
return nil, false
}
}
func NewDuplicateErrors() errors.Mapper {
return duplicateErrors{}
}
+49 -45
View File
@@ -41,6 +41,7 @@ var (
type groupRepository struct {
db postgres.Database
eh errors.Handler
rolesPostgres.Repository
}
@@ -48,9 +49,12 @@ type groupRepository struct {
// repository.
func New(db postgres.Database) groups.Repository {
roleRepo := rolesPostgres.NewRepository(db, policies.GroupType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
errHandlerOptions := []errors.HandlerOption{
postgres.WithDuplicateErrors(NewDuplicateErrors()),
}
return &groupRepository{
db: db,
eh: postgres.NewErrorHandler(errHandlerOptions...),
Repository: roleRepo,
}
}
@@ -62,19 +66,19 @@ func (repo groupRepository) Save(ctx context.Context, g groups.Group) (groups.Gr
}
dbg, err := toDBGroup(g)
if err != nil {
return groups.Group{}, err
return groups.Group{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
}
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
if err != nil {
return groups.Group{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
}
defer row.Close()
row.Next()
dbg = dbGroup{}
if err := row.StructScan(&dbg); err != nil {
return groups.Group{}, err
return groups.Group{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
}
return toGroup(dbg)
@@ -107,7 +111,7 @@ func (repo groupRepository) Update(ctx context.Context, g groups.Group) (groups.
row, err := repo.db.NamedQueryContext(ctx, q, dbu)
if err != nil {
return groups.Group{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
@@ -116,7 +120,7 @@ func (repo groupRepository) Update(ctx context.Context, g groups.Group) (groups.
}
dbu = dbGroup{}
if err := row.StructScan(&dbu); err != nil {
return groups.Group{}, errors.Wrap(err, repoerr.ErrUpdateEntity)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
return toGroup(dbu)
}
@@ -134,14 +138,14 @@ func (repo groupRepository) UpdateTags(ctx context.Context, group groups.Group)
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
if err != nil {
return groups.Group{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
dbg = dbGroup{}
if row.Next() {
if err := row.StructScan(&dbg); err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
return toGroup(dbg)
@@ -160,7 +164,7 @@ func (repo groupRepository) ChangeStatus(ctx context.Context, group groups.Group
}
row, err := repo.db.NamedQueryContext(ctx, qc, dbg)
if err != nil {
return groups.Group{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
if ok := row.Next(); !ok {
@@ -168,7 +172,7 @@ func (repo groupRepository) ChangeStatus(ctx context.Context, group groups.Group
}
dbg = dbGroup{}
if err := row.StructScan(&dbg); err != nil {
return groups.Group{}, errors.Wrap(err, repoerr.ErrUpdateEntity)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
return toGroup(dbg)
@@ -184,7 +188,7 @@ func (repo groupRepository) RetrieveByID(ctx context.Context, id string) (groups
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
if err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
@@ -193,7 +197,7 @@ func (repo groupRepository) RetrieveByID(ctx context.Context, id string) (groups
return groups.Group{}, repoerr.ErrNotFound
}
if err := row.StructScan(&dbg); err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return toGroup(dbg)
}
@@ -341,17 +345,17 @@ func (repo groupRepository) RetrieveByIDWithRoles(ctx context.Context, id, membe
}
row, err := repo.db.NamedQueryContext(ctx, query, parameters)
if err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
dbg := dbGroup{}
if !row.Next() {
return groups.Group{}, errors.Wrap(repoerr.ErrNotFound, err)
return groups.Group{}, repoerr.ErrNotFound
}
if err := row.StructScan(&dbg); err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return toGroup(dbg)
@@ -394,7 +398,7 @@ func (repo groupRepository) RetrieveByIDAndUser(ctx context.Context, domainID, u
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
if err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
@@ -403,7 +407,7 @@ func (repo groupRepository) RetrieveByIDAndUser(ctx context.Context, domainID, u
return groups.Group{}, repoerr.ErrNotFound
}
if err := row.StructScan(&dbg); err != nil {
return groups.Group{}, errors.Wrap(repoerr.ErrViewEntity, err)
return groups.Group{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return toGroup(dbg)
}
@@ -446,13 +450,13 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
if !pm.OnlyTotal {
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
items, err = repo.processRows(rows)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
}
@@ -465,7 +469,7 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, pm groups.PageMeta)
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
page := groups.Page{PageMeta: pm}
@@ -490,13 +494,13 @@ func (repo groupRepository) RetrieveByIDs(ctx context.Context, pm groups.PageMet
}
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
items, err := repo.processRows(rows)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
cq := fmt.Sprintf(` SELECT COUNT(*) AS total_count
@@ -508,7 +512,7 @@ func (repo groupRepository) RetrieveByIDs(ctx context.Context, pm groups.PageMet
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
page := groups.Page{PageMeta: pm}
@@ -570,13 +574,13 @@ func (repo groupRepository) RetrieveHierarchy(ctx context.Context, domainID, use
rows, err := repo.db.NamedQueryContext(ctx, query, parameters)
if err != nil {
return groups.HierarchyPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.HierarchyPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
items, err := repo.processRows(rows)
if err != nil {
return groups.HierarchyPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.HierarchyPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
return groups.HierarchyPage{HierarchyPageMeta: hm, Groups: items}, nil
@@ -589,7 +593,7 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
tx, err := repo.db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer func() {
if err != nil {
@@ -602,13 +606,13 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
pq := `SELECT id, path FROM groups WHERE id = $1 LIMIT 1;`
rows, err := tx.Queryx(pq, parentGroupID)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer rows.Close()
pGroups, err := repo.processRows(rows)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
if len(pGroups) == 0 {
return repoerr.ErrUpdateEntity
@@ -628,7 +632,7 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
for _, sPath := range sPaths {
for _, cgid := range groupIDs {
if sPath == cgid {
return errors.Wrap(repoerr.ErrUpdateEntity, fmt.Errorf("cyclic parent, group %s is parent of requested group %s", cgid, parentGroupID))
return errors.Wrap(repoerr.ErrUpdateEntity, errCyclicParentGroup)
}
}
}
@@ -645,12 +649,12 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
crows, err := tx.NamedQuery(query, params)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer crows.Close()
cgroups, err := repo.processRows(crows)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
childrenPaths := []string{}
@@ -666,11 +670,11 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
WHERE path <@ ANY($2::ltree[]);`
if _, err := tx.Exec(query, pGroup.Path, childrenPaths); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
if err := tx.Commit(); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
return nil
}
@@ -682,7 +686,7 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup
tx, err := repo.db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer func() {
if err != nil {
@@ -694,13 +698,13 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup
pq := `SELECT id, path FROM groups WHERE id = $1 LIMIT 1;`
rows, err := tx.Queryx(pq, parentGroupID)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer rows.Close()
pGroups, err := repo.processRows(rows)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
if len(pGroups) == 0 {
return repoerr.ErrUpdateEntity
@@ -725,12 +729,12 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup
}
crows, err := tx.NamedQuery(query, parameters)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer crows.Close()
cgroups, err := repo.processRows(crows)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
childrenPaths := []string{}
@@ -746,11 +750,11 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup
WHERE path <@ ANY($2::ltree[]);`
if _, err := tx.Exec(query, pGroup.Path, childrenPaths); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
if err := tx.Commit(); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
return nil
}
@@ -764,7 +768,7 @@ func (repo groupRepository) UnassignAllChildrenGroups(ctx context.Context, id st
result, err := repo.db.NamedExecContext(ctx, query, dbGroup{ParentID: &id})
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
return repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
@@ -778,7 +782,7 @@ func (repo groupRepository) Delete(ctx context.Context, groupID string) error {
result, err := repo.db.ExecContext(ctx, q, groupID)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
return repo.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
@@ -922,13 +926,13 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
if !pm.OnlyTotal {
rows, err := repo.db.NamedQueryContext(ctx, q, dbPageMeta)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
items, err = repo.processRows(rows)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
}
@@ -943,7 +947,7 @@ func (repo groupRepository) retrieveGroups(ctx context.Context, domainID, userID
total, err := postgres.Total(ctx, repo.db, cq, dbPageMeta)
if err != nil {
return groups.Page{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return groups.Page{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
page := groups.Page{PageMeta: pm}
+7 -6
View File
@@ -61,6 +61,7 @@ var (
"subgroup_remove_role_users",
"subgroup_view_role_users",
}
errGroupExists = errors.NewRequestError("group id already exists")
)
func TestSave(t *testing.T) {
@@ -106,7 +107,7 @@ func TestSave(t *testing.T) {
{
desc: "add duplicate group",
group: validGroup,
err: repoerr.ErrConflict,
err: errGroupExists,
},
{
desc: "add group with parent",
@@ -125,7 +126,7 @@ func TestSave(t *testing.T) {
CreatedAt: validTimestamp,
Status: groups.EnabledStatus,
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add group with invalid domain",
@@ -138,7 +139,7 @@ func TestSave(t *testing.T) {
CreatedAt: validTimestamp,
Status: groups.EnabledStatus,
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add group with invalid parent",
@@ -164,7 +165,7 @@ func TestSave(t *testing.T) {
CreatedAt: validTimestamp,
Status: groups.EnabledStatus,
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add group with invalid description",
@@ -177,7 +178,7 @@ func TestSave(t *testing.T) {
CreatedAt: validTimestamp,
Status: groups.EnabledStatus,
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add group with invalid metadata",
@@ -205,7 +206,7 @@ func TestSave(t *testing.T) {
CreatedAt: validTimestamp,
Status: groups.EnabledStatus,
},
err: repoerr.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add group with duplicate name",
+13 -8
View File
@@ -19,7 +19,12 @@ import (
"github.com/absmach/supermq/pkg/roles"
)
var ErrGroupIDs = errors.New("invalid group ids")
var (
ErrGroupIDs = errors.New("invalid group ids")
errChangeGroupStatus = errors.NewServiceError("failed to change group status")
errGroupHaveParent = errors.NewRequestError("group already have parent")
errDifferentParent = errors.NewRequestError("groups have different parent")
)
type service struct {
repo Repository
@@ -50,7 +55,7 @@ func NewService(repo Repository, policy policies.Service, idp supermq.IDProvider
func (svc service) CreateGroup(ctx context.Context, session smqauthn.Session, g Group) (retGr Group, retRps []roles.RoleProvision, retErr error) {
groupID, err := svc.idProvider.ID()
if err != nil {
return Group{}, []roles.RoleProvision{}, err
return Group{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrCreateEntity, err)
}
if g.Status != EnabledStatus && g.Status != DisabledStatus {
return Group{}, []roles.RoleProvision{}, svcerr.ErrInvalidStatus
@@ -180,7 +185,7 @@ func (svc service) EnableGroup(ctx context.Context, session smqauthn.Session, id
}
group, err := svc.changeGroupStatus(ctx, session, group)
if err != nil {
return Group{}, err
return Group{}, errors.Wrap(errChangeGroupStatus, err)
}
return group, nil
}
@@ -193,7 +198,7 @@ func (svc service) DisableGroup(ctx context.Context, session smqauthn.Session, i
}
group, err := svc.changeGroupStatus(ctx, session, group)
if err != nil {
return Group{}, err
return Group{}, errors.Wrap(errChangeGroupStatus, err)
}
return group, nil
}
@@ -290,7 +295,7 @@ func (svc service) AddChildrenGroups(ctx context.Context, session smqauthn.Sessi
for _, childGroup := range childrenGroupsPage.Groups {
if childGroup.Parent != "" {
return errors.Wrap(svcerr.ErrConflict, fmt.Errorf("%s group already have parent", childGroup.ID))
return errors.Wrap(svcerr.ErrConflict, errGroupHaveParent)
}
}
@@ -336,7 +341,7 @@ func (svc service) RemoveChildrenGroups(ctx context.Context, session smqauthn.Se
for _, group := range childrenGroupsPage.Groups {
if group.Parent != "" && group.Parent != parentGroupID {
return errors.Wrap(svcerr.ErrConflict, fmt.Errorf("%s group doesn't have same parent", group.ID))
return errors.Wrap(svcerr.ErrConflict, errDifferentParent)
}
pols = append(pols, policies.Policy{
Domain: session.DomainID,
@@ -440,7 +445,7 @@ func (svc service) DeleteGroup(ctx context.Context, session smqauthn.Session, id
}
if err := svc.repo.Delete(ctx, id); err != nil {
return err
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
@@ -452,7 +457,7 @@ func (svc service) changeGroupStatus(ctx context.Context, session smqauthn.Sessi
return Group{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if dbGroup.Status == group.Status {
return Group{}, errors.ErrStatusAlreadyAssigned
return Group{}, svcerr.ErrStatusAlreadyAssigned
}
group.UpdatedBy = session.UserID
+10 -11
View File
@@ -85,9 +85,8 @@ var (
Status: groups.EnabledStatus,
Children: children,
}
validID = testsutil.GenerateUUID(&testing.T{})
errRollbackRoles = errors.New("failed to rollback roles")
validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID}
validID = testsutil.GenerateUUID(&testing.T{})
validSession = authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID}
)
var (
@@ -165,7 +164,7 @@ func TestCreateGroup(t *testing.T) {
group: validGroup,
saveResp: groups.Group{},
saveErr: errors.ErrMalformedEntity,
err: errors.Wrap(svcerr.ErrCreateEntity, errors.ErrMalformedEntity),
err: errors.ErrMalformedEntity,
},
{
desc: " create group with failed to add policies",
@@ -176,7 +175,7 @@ func TestCreateGroup(t *testing.T) {
Domain: validID,
},
addPoliciesErr: svcerr.ErrAuthorization,
err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(svcerr.ErrCreateEntity, svcerr.ErrAuthorization)),
err: svcerr.ErrAddPolicies,
},
{
desc: " create group with failed to add policies and failed rollback",
@@ -188,7 +187,7 @@ func TestCreateGroup(t *testing.T) {
},
addPoliciesErr: svcerr.ErrAuthorization,
deleteErr: svcerr.ErrRemoveEntity,
err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(apiutil.ErrRollbackTx, svcerr.ErrRemoveEntity)),
err: svcerr.ErrRemoveEntity,
},
{
desc: "create group with failed to add roles",
@@ -199,7 +198,7 @@ func TestCreateGroup(t *testing.T) {
Domain: validID,
},
addRoleErr: svcerr.ErrCreateEntity,
err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(svcerr.ErrCreateEntity, svcerr.ErrCreateEntity)),
err: svcerr.ErrAddPolicies,
},
{
desc: "create groups with failed to add roles and failed to delete policies",
@@ -211,7 +210,7 @@ func TestCreateGroup(t *testing.T) {
},
addRoleErr: svcerr.ErrCreateEntity,
deletePoliciesErr: svcerr.ErrRemoveEntity,
err: errors.Wrap(svcerr.ErrAddPolicies, errors.Wrap(svcerr.ErrCreateEntity, errors.Wrap(errRollbackRoles, svcerr.ErrRemoveEntity))),
err: svcerr.ErrAddPolicies,
},
}
@@ -223,7 +222,7 @@ func TestCreateGroup(t *testing.T) {
repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr)
repoCall2 := repo.On("Delete", context.Background(), mock.Anything).Return(tc.deleteErr)
got, _, err := svc.CreateGroup(context.Background(), validSession, tc.group)
assert.Equal(t, tc.err, err, fmt.Sprintf("expected error %v but got %v", tc.err, err))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v but got %v", tc.err, err))
if err == nil {
assert.NotEmpty(t, got.ID)
assert.NotEmpty(t, got.CreatedAt)
@@ -418,7 +417,7 @@ func TestEnableGroup(t *testing.T) {
retrieveResp: groups.Group{
Status: groups.EnabledStatus,
},
err: errors.ErrStatusAlreadyAssigned,
err: svcerr.ErrStatusAlreadyAssigned,
},
{
desc: "enable group with retrieve error",
@@ -472,7 +471,7 @@ func TestDisableGroup(t *testing.T) {
retrieveResp: groups.Group{
Status: groups.DisabledStatus,
},
err: errors.ErrStatusAlreadyAssigned,
err: svcerr.ErrStatusAlreadyAssigned,
},
{
desc: "disable group with retrieve error",
+6 -2
View File
@@ -133,7 +133,9 @@ func TestAuthConnect(t *testing.T) {
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
assert.True(t, errors.Contains(err, tc.err))
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error containing: %v, got: %v", tc.err, err))
}
})
}
}
@@ -450,7 +452,9 @@ func TestPublish(t *testing.T) {
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err))
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error containing: %v, got: %v", tc.err, err))
}
authCall.Unset()
repoCall.Unset()
clientsCall.Unset()
+24
View File
@@ -0,0 +1,24 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import "github.com/absmach/supermq/pkg/errors"
var _ errors.Mapper = (*duplicateErrors)(nil)
type duplicateErrors struct{}
// GetError maps constraint names to known errors.
func (d duplicateErrors) GetError(constraint string) (error, bool) {
switch constraint {
case "journal_pkey":
return errors.NewRequestError("journal entry already exists"), true
default:
return nil, false
}
}
func NewDuplicateErrors() errors.Mapper {
return duplicateErrors{}
}
+13 -6
View File
@@ -18,10 +18,17 @@ import (
type repository struct {
db postgres.Database
eh errors.Handler
}
func NewRepository(db postgres.Database) journal.Repository {
return &repository{db: db}
errHandlerOptions := []errors.HandlerOption{
postgres.WithDuplicateErrors(NewDuplicateErrors()),
}
return &repository{
db: db,
eh: postgres.NewErrorHandler(errHandlerOptions...),
}
}
func (repo *repository) Save(ctx context.Context, j journal.Journal) (err error) {
@@ -41,11 +48,11 @@ func (repo *repository) Save(ctx context.Context, j journal.Journal) (err error)
dbJournal, err := toDBJournal(j)
if err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
return repo.eh.HandleError(repoerr.ErrCreateEntity, err)
}
if _, err = repo.db.NamedExecContext(ctx, q, dbJournal); err != nil {
return postgres.HandleError(repoerr.ErrCreateEntity, err)
return repo.eh.HandleError(repoerr.ErrCreateEntity, err)
}
return nil
@@ -68,7 +75,7 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
rows, err := repo.db.NamedQueryContext(ctx, q, page)
if err != nil {
return journal.JournalsPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
@@ -76,7 +83,7 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
for rows.Next() {
var item dbJournal
if err = rows.StructScan(&item); err != nil {
return journal.JournalsPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
j, err := toJournal(item)
if err != nil {
@@ -89,7 +96,7 @@ func (repo *repository) RetrieveAll(ctx context.Context, page journal.Page) (jou
total, err := postgres.Total(ctx, repo.db, tq, page)
if err != nil {
return journal.JournalsPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return journal.JournalsPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
journalsPage := journal.JournalsPage{
+3 -2
View File
@@ -87,7 +87,8 @@ var (
"created_at": time.Now().Add(-time.Hour),
"name": "group",
}
validTimeStamp = time.Now().UTC().Truncate(time.Millisecond)
validTimeStamp = time.Now().UTC().Truncate(time.Millisecond)
errJournalExists = errors.NewRequestError("journal entry already exists")
)
func TestJournalSave(t *testing.T) {
@@ -125,7 +126,7 @@ func TestJournalSave(t *testing.T) {
Attributes: payload,
Metadata: payload,
},
err: repoerr.ErrConflict,
err: errJournalExists,
},
{
desc: "with massive journal metadata and attributes",
+14 -10
View File
@@ -417,7 +417,7 @@ func TestPublish(t *testing.T) {
session: &sessionClient,
topic: malformedSubtopics,
payload: payload,
err: errors.Wrap(mqtt.ErrFailedPublish, errors.Wrap(messaging.ErrMalformedTopic, errors.Wrap(messaging.ErrMalformedSubtopic, errors.New("invalid URL escape \"%\"")))),
err: errors.New("invalid URL escape \"%\""),
},
{
desc: "publish with subtopic containing wrong character",
@@ -457,15 +457,19 @@ func TestPublish(t *testing.T) {
}
for _, tc := range cases {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
err := handler.Publish(ctx, &tc.topic, &tc.payload)
assert.Contains(t, logBuffer.String(), tc.logMsg)
assert.Equal(t, tc.err, err)
repoCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
ctx := context.TODO()
if tc.session != nil {
ctx = session.NewContext(ctx, tc.session)
}
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
err := handler.Publish(ctx, &tc.topic, &tc.payload)
assert.Contains(t, logBuffer.String(), tc.logMsg)
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error containing: %v, got: %v", tc.err, err))
}
repoCall.Unset()
})
}
}
+244
View File
@@ -0,0 +1,244 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package errors
import (
"errors"
"fmt"
)
type NestError interface {
Error
Embed(e error) error
}
var _ NestError = (*customError)(nil)
func (e *customError) Embed(err error) error {
if err == nil {
return e
}
return &customError{
msg: e.msg,
err: fmt.Errorf("%w: %w", e.err, err),
}
}
type RequestError struct {
customError
}
var _ NestError = (*RequestError)(nil)
func NewRequestError(message string) NestError {
return &RequestError{
customError: customError{
msg: message,
err: errors.New(message),
},
}
}
func NewRequestErrorWithErr(message string, err error) NestError {
return &RequestError{
customError: customError{
msg: message,
err: fmt.Errorf("%w: %w", errors.New(message), err),
},
}
}
func (e *RequestError) Embed(err error) error {
embedded := e.customError.Embed(err)
return &RequestError{
customError: *embedded.(*customError),
}
}
type AuthNError struct {
customError
}
var _ NestError = (*AuthNError)(nil)
func NewAuthNError(message string) NestError {
return &AuthNError{
customError: customError{
msg: message,
err: errors.New(message),
},
}
}
func NewAuthNErrorWithErr(message string, err error) NestError {
return &AuthNError{
customError: customError{
msg: message,
err: fmt.Errorf("%w: %w", errors.New(message), err),
},
}
}
func (e *AuthNError) Embed(err error) error {
embedded := e.customError.Embed(err)
return &AuthNError{
customError: *embedded.(*customError),
}
}
var _ NestError = (*AuthZError)(nil)
type AuthZError struct {
customError
}
func (e *AuthZError) Embed(err error) error {
embedded := e.customError.Embed(err)
return &AuthZError{
customError: *embedded.(*customError),
}
}
func NewAuthZError(message string) NestError {
return &AuthZError{
customError: customError{
msg: message,
err: errors.New(message),
},
}
}
func NewAuthZErrorWithErr(message string, err error) NestError {
return &AuthZError{
customError: customError{
msg: message,
err: fmt.Errorf("%w: %w", errors.New(message), err),
},
}
}
type InternalError struct {
customError
}
var _ NestError = (*InternalError)(nil)
func NewInternalError() error {
return &InternalError{
customError: customError{
msg: "internal server error",
err: errors.New("internal server error"),
},
}
}
func NewInternalErrorWithErr(err error) NestError {
return &InternalError{
customError: customError{
msg: "internal server error",
err: fmt.Errorf("%w: %w", errors.New("internal server error"), err),
},
}
}
func (e *InternalError) Embed(err error) error {
embedded := e.customError.Embed(err)
return &InternalError{
customError: *embedded.(*customError),
}
}
type ServiceError struct {
customError
}
var _ NestError = (*ServiceError)(nil)
func NewServiceError(message string) NestError {
return &ServiceError{
customError: customError{
msg: message,
err: errors.New(message),
},
}
}
func NewServiceErrorWithErr(message string, err error) NestError {
return &ServiceError{
customError: customError{
msg: message,
err: fmt.Errorf("%w: %w", errors.New(message), err),
},
}
}
func (e *ServiceError) Embed(err error) error {
embedded := e.customError.Embed(err)
return &ServiceError{
customError: *embedded.(*customError),
}
}
type MediaTypeError struct {
customError
}
var _ NestError = (*MediaTypeError)(nil)
func NewMediaTypeError(message string) NestError {
return &MediaTypeError{
customError: customError{
msg: message,
err: errors.New(message),
},
}
}
func NewMediaTypeErrorWithErr(message string, err error) NestError {
return &MediaTypeError{
customError: customError{
msg: message,
err: fmt.Errorf("%w: %w", errors.New(message), err),
},
}
}
func (e *MediaTypeError) Embed(err error) error {
embedded := e.customError.Embed(err)
return &MediaTypeError{
customError: *embedded.(*customError),
}
}
type NotFoundError struct {
customError
}
var _ NestError = (*NotFoundError)(nil)
func NewNotFoundError(message string) NestError {
return &NotFoundError{
customError: customError{
msg: message,
err: errors.New(message),
},
}
}
func NewNotFoundErrorWithErr(message string, err error) NestError {
return &NotFoundError{
customError: customError{
msg: message,
err: fmt.Errorf("%w: %w", errors.New(message), err),
},
}
}
func (e *NotFoundError) Embed(err error) error {
embedded := e.customError.Embed(err)
return &NotFoundError{
customError: *embedded.(*customError),
}
}
+15 -31
View File
@@ -5,6 +5,8 @@ package errors
import (
"encoding/json"
"errors"
"fmt"
)
// Error specifies an API that must be fullfiled by error type.
@@ -16,7 +18,7 @@ type Error interface {
Msg() string
// Err returns wrapped error.
Err() Error
Err() error
// MarshalJSON returns a marshaled error.
MarshalJSON() ([]byte, error)
@@ -27,14 +29,14 @@ var _ Error = (*customError)(nil)
// customError represents a SuperMQ error.
type customError struct {
msg string
err Error
err error
}
// New returns an Error that formats as the given text.
func New(text string) Error {
return &customError{
msg: text,
err: nil,
err: errors.New(text),
}
}
@@ -45,32 +47,25 @@ func (ce *customError) Error() string {
if ce.err == nil {
return ce.msg
}
return ce.msg + " : " + ce.err.Error()
return ce.err.Error()
}
func (ce *customError) Msg() string {
return ce.msg
}
func (ce *customError) Err() Error {
func (ce *customError) Err() error {
return ce.err
}
func (ce *customError) MarshalJSON() ([]byte, error) {
var val string
if e := ce.Err(); e != nil {
val = e.Msg()
}
return json.Marshal(&struct {
Err string `json:"error"`
Msg string `json:"message"`
}{
Err: val,
Msg: ce.Msg(),
})
}
// Contains inspects if e2 error is contained in any layer of e1 error.
func Contains(e1, e2 error) bool {
if e1 == nil || e2 == nil {
return e2 == e1
@@ -82,7 +77,8 @@ func Contains(e1, e2 error) bool {
}
return Contains(ce.Err(), e2)
}
return e1.Error() == e2.Error()
return errors.Is(e1, e2) || e1.Error() == e2.Error()
}
// Wrap returns an Error that wrap err with wrapper.
@@ -90,30 +86,18 @@ func Wrap(wrapper, err error) error {
if wrapper == nil || err == nil {
return wrapper
}
if w, ok := wrapper.(Error); ok {
return &customError{
msg: w.Msg(),
err: cast(err),
}
if ne, ok := err.(NestError); ok {
return ne.Embed(wrapper)
}
if ce, ok := wrapper.(NestError); ok {
return ce.Embed(err)
}
return &customError{
msg: wrapper.Error(),
err: cast(err),
err: fmt.Errorf("%w: %w", wrapper, err),
}
}
// Unwrap returns the wrapper and the error by separating the Wrapper from the error.
func Unwrap(err error) (error, error) {
if ce, ok := err.(Error); ok {
if ce.Err() == nil {
return nil, New(ce.Msg())
}
return New(ce.Msg()), ce.Err()
}
return nil, err
}
func cast(err error) Error {
if err == nil {
return nil
+9 -68
View File
@@ -34,35 +34,35 @@ func TestError(t *testing.T) {
desc: "level 0 wrapped error",
err: err0,
msg: "0",
bytes: []byte(`{"error":"","message":"0"}`),
bytes: []byte(`{"message":"0"}`),
bytesErr: nil,
},
{
desc: "level 1 wrapped error",
err: wrap(1),
msg: message(1),
bytes: []byte(`{"error":"0","message":"1"}`),
bytes: []byte(`{"message":"0"}`),
bytesErr: nil,
},
{
desc: "level 2 wrapped error",
err: wrap(2),
msg: message(2),
bytes: []byte(`{"error":"1","message":"2"}`),
bytes: []byte(`{"message":"0"}`),
bytesErr: nil,
},
{
desc: fmt.Sprintf("level %d wrapped error", level),
err: wrap(level),
msg: message(level),
bytes: []byte(`{"error":"9","message":"` + strconv.Itoa(level) + `"}`),
bytes: []byte(`{"message":"0"}`),
bytesErr: nil,
},
{
desc: "nil error",
err: errors.New(""),
msg: "",
bytes: []byte(`{"error":"","message":""}`),
bytes: []byte(`{"message":""}`),
bytesErr: nil,
},
}
@@ -129,9 +129,9 @@ func TestContains(t *testing.T) {
contains: true,
},
{
desc: fmt.Sprintf("level %d wrapped error contains", level),
desc: fmt.Sprintf("level %d wrapped error contains err0", level),
container: wrap(level),
contained: errors.New(strconv.Itoa(level / 2)),
contained: err0,
contains: true,
},
{
@@ -276,66 +276,6 @@ func TestWrap(t *testing.T) {
}
}
func TestUnwrap(t *testing.T) {
cases := []struct {
desc string
err error
wrapper error
wrapped error
}{
{
desc: "err 1 wraped err 2",
err: errors.Wrap(err1, err2),
wrapper: err1,
wrapped: err2,
},
{
desc: "err2 wraps err1 wraps err0",
err: errors.Wrap(err2, errors.Wrap(err1, err0)),
wrapper: err2,
wrapped: errors.Wrap(err1, err0),
},
{
desc: "nil wraps nil",
err: errors.Wrap(nil, nil),
wrapper: nil,
wrapped: nil,
},
{
desc: "err0 wraps nil",
err: errors.Wrap(err0, nil),
wrapper: nil,
wrapped: err0,
},
{
desc: "nil wraps err0",
err: errors.Wrap(nil, err0),
wrapper: nil,
wrapped: nil,
},
{
desc: "nil wraps native error",
err: errors.Wrap(nil, nat),
wrapper: nil,
wrapped: nil,
},
{
desc: "native error wraps nil",
err: errors.Wrap(nat, nil),
wrapper: nil,
wrapped: nat,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
wrapper, wrapped := errors.Unwrap(c.err)
assert.Equal(t, c.wrapper, wrapper)
assert.Equal(t, c.wrapped, wrapped)
})
}
}
func wrap(level int) error {
if level == 0 {
return errors.New(strconv.Itoa(level))
@@ -344,9 +284,10 @@ func wrap(level int) error {
}
// message generates error message of wrap() generated wrapper error.
// The error message format is now "innermost: ... : outermost" due to fmt.Errorf wrapping.
func message(level int) string {
if level == 0 {
return "0"
}
return strconv.Itoa(level) + " : " + message(level-1)
return message(level-1) + ": " + strconv.Itoa(level)
}
+14
View File
@@ -0,0 +1,14 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package errors
type Mapper interface {
GetError(key string) (error, bool)
}
type Handler interface {
HandleError(wrapper, err error) error
}
type HandlerOption func(*Handler)
+10 -1
View File
@@ -11,7 +11,7 @@ var (
ErrMalformedEntity = errors.New("malformed entity specification")
// ErrNotFound indicates a non-existent entity request.
ErrNotFound = errors.New("entity not found")
ErrNotFound = errors.NewNotFoundError("entity not found")
// ErrConflict indicates that entity already exists.
ErrConflict = errors.New("entity already exists")
@@ -39,4 +39,13 @@ var (
// ErrMissingNames indicates missing first and last names.
ErrMissingNames = errors.New("missing first or last name")
// ErrMarshalBDEntity indicates a failure to marshal a database entity.
ErrMarshalBDEntity = errors.New("failed to marshal db entity")
// ErrUnmarshalBDEntity indicates a failure to unmarshal a database entity.
ErrUnmarshalBDEntity = errors.New("failed to unmarshal db entity")
// ErrParseQueryParams indicates a failure to parse query parameters.
ErrParseQueryParams = errors.NewRequestError("failed to parse query parameters")
)
+16 -7
View File
@@ -43,8 +43,11 @@ func TestNewSDKError(t *testing.T) {
t.Run(c.desc, func(t *testing.T) {
sdk := errors.NewSDKError(c.err)
if c.err != nil {
assert.Equal(t, sdk.StatusCode(), 0)
assert.Equal(t, sdk.Error(), fmt.Sprintf("Status: %s: %s", http.StatusText(0), c.err.Error()))
assert.NotNil(t, sdk)
assert.Equal(t, 0, sdk.StatusCode())
assert.Equal(t, fmt.Sprintf("Status: %s: %s", http.StatusText(0), c.err.Error()), sdk.Error())
} else {
assert.Nil(t, sdk)
}
})
}
@@ -102,8 +105,11 @@ func TestNewSDKErrorWithStatus(t *testing.T) {
t.Run(c.desc, func(t *testing.T) {
sdk := errors.NewSDKErrorWithStatus(c.err, c.sc)
if c.err != nil {
assert.Equal(t, sdk.StatusCode(), c.sc)
assert.Equal(t, sdk.Error(), fmt.Sprintf("Status: %s: %s", http.StatusText(c.sc), c.err.Error()))
assert.NotNil(t, sdk)
assert.Equal(t, c.sc, sdk.StatusCode())
assert.Equal(t, fmt.Sprintf("Status: %s: %s", http.StatusText(c.sc), c.err.Error()), sdk.Error())
} else {
assert.Nil(t, sdk)
}
})
}
@@ -196,10 +202,13 @@ func TestCheckError(t *testing.T) {
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
sdk := errors.CheckError(c.resp, c.codes...)
assert.Equal(t, sdk, c.err)
if c.err != nil {
assert.Equal(t, sdk, c.err)
assert.Equal(t, sdk.StatusCode(), c.resp.StatusCode)
assert.NotNil(t, sdk)
assert.Equal(t, c.err.StatusCode(), sdk.StatusCode())
assert.Equal(t, c.err.Error(), sdk.Error())
assert.Equal(t, c.resp.StatusCode, sdk.StatusCode())
} else {
assert.Nil(t, sdk)
}
})
}
+70 -55
View File
@@ -8,101 +8,116 @@ import "github.com/absmach/supermq/pkg/errors"
// Wrapper for Service errors.
var (
// ErrAuthentication indicates failure occurred while authenticating the entity.
ErrAuthentication = errors.New("failed to perform authentication over the entity")
// ErrAuthorization indicates failure occurred while authorizing the entity.
ErrAuthorization = errors.New("failed to perform authorization over the entity")
// ErrDomainAuthorization indicates failure occurred while authorizing the domain.
ErrDomainAuthorization = errors.New("failed to perform authorization over the domain")
ErrAuthentication = errors.NewAuthNError("failed to perform authentication over the entity")
// ErrLogin indicates wrong login credentials.
ErrLogin = errors.New("invalid credentials")
ErrLogin = errors.NewAuthNError("invalid credentials")
// ErrAuthorization indicates failure occurred while authorizing the entity.
ErrAuthorization = errors.NewAuthZError("failed to perform authorization over the entity")
// ErrDomainAuthorization indicates failure occurred while authorizing the domain.
ErrDomainAuthorization = errors.NewAuthZError("failed to perform authorization over the domain")
// ErrUnauthorizedPAT indicates failure occurred while authorizing PAT.
ErrUnauthorizedPAT = errors.NewAuthZError("failed to authorize PAT")
// ErrSuperAdminAction indicates that the user is not a super admin.
ErrSuperAdminAction = errors.NewAuthZError("not authorized to perform admin action")
// ErrCreateEntity indicates error in creating entity or entities.
ErrCreateEntity = errors.NewServiceError("failed to create entity")
// ErrRemoveEntity indicates error in removing entity.
ErrRemoveEntity = errors.NewServiceError("failed to remove entity")
// ErrViewEntity indicates error in viewing entity or entities.
ErrViewEntity = errors.NewServiceError("view entity failed")
// ErrUpdateEntity indicates error in updating entity or entities.
ErrUpdateEntity = errors.NewServiceError("update entity failed")
// ErrAddPolicies indicates error in adding policies.
ErrAddPolicies = errors.NewServiceError("failed to add policies")
// ErrUserAlreadyVerified indicates user is already verified.
ErrUserAlreadyVerified = errors.NewServiceError("user already verified")
// ErrInvalidUserVerification indicates user verification is invalid.
ErrInvalidUserVerification = errors.NewServiceError("invalid verification")
// ErrIssueProviderID indicates failure to issue unique ID from ID provider.
ErrIssueProviderID = errors.NewServiceError("failed to issue unique ID from id provider")
// ErrHashPassword indicates failure to hash password.
ErrHashPassword = errors.NewServiceError("failed to hash password")
// ErrStatusAlreadyAssigned indicates that the client or group has already been assigned the status.
ErrStatusAlreadyAssigned = errors.NewServiceError("status already assigned")
// ErrDeletePolicies indicates error in removing policies.
ErrDeletePolicies = errors.NewServiceError("failed to remove policies")
// ErrMissingUsername indicates that the user's names are missing.
ErrMissingUsername = errors.NewRequestError("missing usernames")
// ErrInvalidStatus indicates an invalid status.
ErrInvalidStatus = errors.NewRequestError("invalid status")
// ErrInvalidRole indicates that an invalid role.
ErrInvalidRole = errors.NewRequestError("invalid client role")
// ErrMalformedEntity indicates a malformed entity specification.
ErrMalformedEntity = errors.New("malformed entity specification")
// ErrNotFound indicates a non-existent entity request.
ErrNotFound = errors.New("entity not found")
ErrNotFound = errors.NewNotFoundError("entity not found")
// ErrConflict indicates that entity already exists.
ErrConflict = errors.New("entity already exists")
// ErrCreateEntity indicates error in creating entity or entities.
ErrCreateEntity = errors.New("failed to create entity")
// ErrRemoveEntity indicates error in removing entity.
ErrRemoveEntity = errors.New("failed to remove entity")
// ErrViewEntity indicates error in viewing entity or entities.
ErrViewEntity = errors.New("view entity failed")
// ErrUpdateEntity indicates error in updating entity or entities.
ErrUpdateEntity = errors.New("update entity failed")
// ErrInvalidStatus indicates an invalid status.
ErrInvalidStatus = errors.New("invalid status")
// ErrInvalidRole indicates that an invalid role.
ErrInvalidRole = errors.New("invalid client role")
ErrConflict = errors.NewRequestError("entity already exists")
// ErrInvalidPolicy indicates that an invalid policy.
ErrInvalidPolicy = errors.New("invalid policy")
// ErrEnableClient indicates error in enabling client.
ErrEnableClient = errors.New("failed to enable client")
ErrEnableClient = errors.NewServiceError("failed to enable client")
// ErrDisableClient indicates error in disabling client.
ErrDisableClient = errors.New("failed to disable client")
// ErrAddPolicies indicates error in adding policies.
ErrAddPolicies = errors.New("failed to add policies")
// ErrDeletePolicies indicates error in removing policies.
ErrDeletePolicies = errors.New("failed to remove policies")
ErrDisableClient = errors.NewServiceError("failed to disable client")
// ErrSearch indicates error in searching clients.
ErrSearch = errors.New("failed to search clients")
// ErrInvitationAlreadyRejected indicates that the invitation is already rejected.
ErrInvitationAlreadyRejected = errors.New("invitation already rejected")
ErrInvitationAlreadyRejected = errors.NewRequestError("invitation already rejected")
// ErrInvitationAlreadyAccepted indicates that the invitation is already accepted.
ErrInvitationAlreadyAccepted = errors.New("invitation already accepted")
ErrInvitationAlreadyAccepted = errors.NewRequestError("invitation already accepted")
// ErrParentGroupAuthorization indicates failure occurred while authorizing the parent group.
ErrParentGroupAuthorization = errors.New("failed to authorize parent group")
// ErrMissingUsername indicates that the user's names are missing.
ErrMissingUsername = errors.New("missing usernames")
// ErrEnableUser indicates error in enabling user.
ErrEnableUser = errors.New("failed to enable user")
ErrEnableUser = errors.NewServiceError("failed to enable user")
// ErrDisableUser indicates error in disabling user.
ErrDisableUser = errors.New("failed to disable user")
ErrDisableUser = errors.NewServiceError("failed to disable user")
// ErrRollbackRepo indicates a failure to rollback repository.
ErrRollbackRepo = errors.New("failed to rollback repo")
// ErrUnauthorizedPAT indicates failure occurred while authorizing PAT.
ErrUnauthorizedPAT = errors.New("failed to authorize PAT")
// ErrRetainOneMember indicates that at least one owner must be retained in the entity.
ErrRetainOneMember = errors.New("must retain at least one member")
// ErrSuperAdminAction indicates that the user is not a super admin.
ErrSuperAdminAction = errors.New("not authorized to perform admin action")
// ErrUserAlreadyVerified indicates user is already verified.
ErrUserAlreadyVerified = errors.New("user already verified")
// ErrInvalidUserVerification indicates user verification is invalid.
ErrInvalidUserVerification = errors.New("invalid verification")
// ErrUserVerificationExpired indicates user verification is expired.
ErrUserVerificationExpired = errors.New("verification expired, please generate new verification")
// ErrRegisterUser indicates error in register a user.
ErrRegisterUser = errors.New("failed to register user")
// ErrExternalAuthProviderCouldNotUpdate indicates that users authenticated via external provider cannot update their account details directly.
ErrExternalAuthProviderCouldNotUpdate = errors.New("account details can only be updated through your authentication provider's settings")
// ErrFailedToSaveEntityDB indicates failure to save entity to database.
ErrFailedToSaveEntityDB = errors.New("failed to save entity to database")
)
+3 -6
View File
@@ -5,7 +5,7 @@ package errors
var (
// ErrMalformedEntity indicates a malformed entity specification.
ErrMalformedEntity = New("malformed entity specification")
ErrMalformedEntity = NewRequestError("malformed entity specification")
// ErrUnsupportedContentType indicates invalid content type.
ErrUnsupportedContentType = New("invalid content type")
@@ -16,9 +16,6 @@ var (
// ErrEmptyPath indicates empty file path.
ErrEmptyPath = New("empty file path")
// ErrStatusAlreadyAssigned indicated that the client or group has already been assigned the status.
ErrStatusAlreadyAssigned = New("status already assigned")
// ErrRollbackTx indicates failed to rollback transaction.
ErrRollbackTx = New("failed to rollback transaction")
@@ -35,7 +32,7 @@ var (
ErrMissingMember = New("member id is not found")
// ErrEmailAlreadyExists indicates that the email id already exists.
ErrEmailAlreadyExists = New("email id already exists")
ErrEmailAlreadyExists = New("email id already registered")
// ErrUsernameNotAvailable indicates that the username is not available.
ErrUsernameNotAvailable = New("username not available")
@@ -50,5 +47,5 @@ var (
ErrTryAgain = New("Something went wrong, please try again")
// ErrRouteNotAvailable indicates that the username is not available.
ErrRouteNotAvailable = New("route not available")
ErrRouteNotAvailable = NewRequestError("route not available")
)
+53
View File
@@ -0,0 +1,53 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"github.com/absmach/supermq/pkg/errors"
"github.com/jackc/pgx/v5/pgconn"
)
var _ errors.Handler = (*errHandler)(nil)
type errHandler struct {
duplicateErrors errors.Mapper
}
func WithDuplicateErrors(mapper errors.Mapper) errors.HandlerOption {
return func(eh *errors.Handler) {
if h, ok := (*eh).(*errHandler); ok {
h.duplicateErrors = mapper
}
}
}
func NewErrorHandler(opts ...errors.HandlerOption) errors.Handler {
var eh errors.Handler = &errHandler{}
for _, opt := range opts {
opt(&eh)
}
return eh
}
// Handle handles the error.
func (eh errHandler) HandleError(wrapper, err error) error {
pqErr, ok := err.(*pgconn.PgError)
if ok {
switch pqErr.Code {
case errDuplicate:
if eh.duplicateErrors != nil {
if knownErr, ok := eh.duplicateErrors.GetError(pqErr.ConstraintName); ok {
return errors.Wrap(wrapper, knownErr)
}
}
return errors.Wrap(wrapper, err)
case errInvalid, errInvalidChar, errTruncation, errUntranslatable:
return errors.Wrap(wrapper, err)
case errFK:
return errors.Wrap(wrapper, err)
}
}
return errors.Wrap(wrapper, err)
}
+7 -7
View File
@@ -32,7 +32,7 @@ func (d Decoder) DecodeCreateRole(_ context.Context, r *http.Request) (any, erro
entityID: chi.URLParam(r, d.entityIDTemplate),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -135,7 +135,7 @@ func (d Decoder) DecodeRemoveEntityMembers(_ context.Context, r *http.Request) (
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -159,7 +159,7 @@ func (d Decoder) DecodeUpdateRole(_ context.Context, r *http.Request) (any, erro
roleID: chi.URLParam(r, "roleID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -190,7 +190,7 @@ func (d Decoder) DecodeAddRoleActions(_ context.Context, r *http.Request) (any,
roleID: chi.URLParam(r, "roleID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -214,7 +214,7 @@ func (d Decoder) DecodeDeleteRoleActions(_ context.Context, r *http.Request) (an
roleID: chi.URLParam(r, "roleID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -238,7 +238,7 @@ func (d Decoder) DecodeAddRoleMembers(_ context.Context, r *http.Request) (any,
roleID: chi.URLParam(r, "roleID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
@@ -272,7 +272,7 @@ func (d Decoder) DecodeDeleteRoleMembers(_ context.Context, r *http.Request) (an
roleID: chi.URLParam(r, "roleID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
+32 -32
View File
@@ -132,7 +132,7 @@ func TestCreateChannel(t *testing.T) {
svcRes: []channels.Channel{},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "create channel with parent group",
@@ -214,7 +214,7 @@ func TestCreateChannel(t *testing.T) {
svcRes: []channels.Channel{iChannel},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -316,7 +316,7 @@ func TestCreateChannels(t *testing.T) {
svcRes: []channels.Channel{},
svcErr: nil,
response: []sdk.Channel{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "create channels with service response that can't be unmarshalled",
@@ -334,7 +334,7 @@ func TestCreateChannels(t *testing.T) {
},
svcErr: nil,
response: []sdk.Channel{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
@@ -493,7 +493,7 @@ func TestListChannels(t *testing.T) {
svcRes: channels.ChannelsPage{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
},
{
desc: "list channels with level",
@@ -571,7 +571,7 @@ func TestListChannels(t *testing.T) {
svcRes: channels.ChannelsPage{},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "list channels with service response that can't be unmarshalled",
@@ -599,7 +599,7 @@ func TestListChannels(t *testing.T) {
},
svcErr: nil,
response: sdk.ChannelsPage{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
@@ -709,9 +709,9 @@ func TestViewChannel(t *testing.T) {
withRoles: false,
channelID: wrongID,
svcRes: channels.Channel{},
svcErr: svcerr.ErrViewEntity,
svcErr: svcerr.ErrNotFound,
response: sdk.Channel{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
{
desc: "view channel with empty channel id",
@@ -738,7 +738,7 @@ func TestViewChannel(t *testing.T) {
},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
@@ -971,7 +971,7 @@ func TestUpdateChannel(t *testing.T) {
svcRes: channels.Channel{},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
},
{
desc: "update channel that can't be marshalled",
@@ -988,7 +988,7 @@ func TestUpdateChannel(t *testing.T) {
svcRes: channels.Channel{},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update channel with service response that can't be unmarshalled",
@@ -1010,7 +1010,7 @@ func TestUpdateChannel(t *testing.T) {
},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
{
desc: "update channel with empty channel id",
@@ -1157,7 +1157,7 @@ func TestUpdateChannelTags(t *testing.T) {
svcRes: channels.Channel{},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update channel tags with a response that can't be unmarshalled",
@@ -1174,7 +1174,7 @@ func TestUpdateChannelTags(t *testing.T) {
},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1266,7 +1266,7 @@ func TestEnableChannel(t *testing.T) {
svcRes: channels.Channel{},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "enable channel with service response that can't be unmarshalled",
@@ -1281,7 +1281,7 @@ func TestEnableChannel(t *testing.T) {
},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1376,7 +1376,7 @@ func TestDisableChannel(t *testing.T) {
svcRes: channels.Channel{},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "disable channel with service response that can't be unmarshalled",
@@ -1391,7 +1391,7 @@ func TestDisableChannel(t *testing.T) {
},
svcErr: nil,
response: sdk.Channel{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1572,7 +1572,7 @@ func TestConnect(t *testing.T) {
Types: []string{"Publish", "Subscribe"},
},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "connect with empty client id",
@@ -1584,7 +1584,7 @@ func TestConnect(t *testing.T) {
Types: []string{"Publish", "Subscribe"},
},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
for _, tc := range cases {
@@ -1679,7 +1679,7 @@ func TestDisconnect(t *testing.T) {
Types: []string{"Publish", "Subscribe"},
},
svcErr: svcerr.ErrAuthorization,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
},
{
desc: "disconnect with empty channel id",
@@ -1691,7 +1691,7 @@ func TestDisconnect(t *testing.T) {
Types: []string{"Publish", "Subscribe"},
},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "disconnect with empty client id",
@@ -1703,7 +1703,7 @@ func TestDisconnect(t *testing.T) {
Types: []string{"Publish", "Subscribe"},
},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
for _, tc := range cases {
@@ -1802,7 +1802,7 @@ func TestConnectClients(t *testing.T) {
clientID: clientID,
connType: "Publish",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "connect with empty client id",
@@ -1812,7 +1812,7 @@ func TestConnectClients(t *testing.T) {
clientID: "",
connType: "Publish",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
},
}
for _, tc := range cases {
@@ -1895,7 +1895,7 @@ func TestDisconnectClients(t *testing.T) {
channelID: wrongID,
clientID: clientID,
connType: "Publish",
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
},
{
desc: "disconnect with empty channel id",
@@ -1905,7 +1905,7 @@ func TestDisconnectClients(t *testing.T) {
clientID: clientID,
connType: "Publish",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "disconnect with empty client id",
@@ -1915,7 +1915,7 @@ func TestDisconnectClients(t *testing.T) {
clientID: "",
connType: "Publish",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
},
}
for _, tc := range cases {
@@ -2002,7 +2002,7 @@ func TestSetChannelParent(t *testing.T) {
channelID: "",
parentID: parentID,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "set channel parent with empty parent id",
@@ -2011,7 +2011,7 @@ func TestSetChannelParent(t *testing.T) {
channelID: channel.ID,
parentID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingParentGroupID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingParentGroupID, http.StatusBadRequest),
},
}
@@ -2097,7 +2097,7 @@ func TestRemoveChannelParent(t *testing.T) {
channelID: "",
parentID: parentID,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
+51 -51
View File
@@ -136,7 +136,7 @@ func TestCreateClient(t *testing.T) {
svcRes: []clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
},
{
desc: "create a client with invalid id",
@@ -154,7 +154,7 @@ func TestCreateClient(t *testing.T) {
svcRes: []clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
},
{
desc: "create a client with a request that can't be marshalled",
@@ -170,7 +170,7 @@ func TestCreateClient(t *testing.T) {
svcRes: []clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "create a client with a response that can't be unmarshalled",
@@ -188,7 +188,7 @@ func TestCreateClient(t *testing.T) {
}},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -281,7 +281,7 @@ func TestCreateClients(t *testing.T) {
svcRes: []clients.Client{},
svcErr: nil,
response: []sdk.Client{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "create new clients with a response that can't be unmarshalled",
@@ -299,7 +299,7 @@ func TestCreateClients(t *testing.T) {
}},
svcErr: nil,
response: []sdk.Client{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -422,7 +422,7 @@ func TestListClients(t *testing.T) {
svcRes: clients.ClientsPage{},
svcErr: nil,
response: sdk.ClientsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
},
{
desc: "list all clients with name size greater than max",
@@ -441,7 +441,7 @@ func TestListClients(t *testing.T) {
svcRes: clients.ClientsPage{},
svcErr: nil,
response: sdk.ClientsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
},
{
desc: "list all clients with status",
@@ -532,7 +532,7 @@ func TestListClients(t *testing.T) {
svcRes: clients.ClientsPage{},
svcErr: nil,
response: sdk.ClientsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "list all clients with response that can't be unmarshalled",
@@ -566,7 +566,7 @@ func TestListClients(t *testing.T) {
},
svcErr: nil,
response: sdk.ClientsPage{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -670,9 +670,9 @@ func TestViewClient(t *testing.T) {
withRoles: false,
clientID: wrongID,
svcRes: clients.Client{},
svcErr: svcerr.ErrViewEntity,
svcErr: svcerr.ErrNotFound,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
{
desc: "view client with empty client id",
@@ -701,7 +701,7 @@ func TestViewClient(t *testing.T) {
},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -852,7 +852,7 @@ func TestUpdateClient(t *testing.T) {
svcRes: clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update client with a response that can't be unmarshalled",
@@ -870,7 +870,7 @@ func TestUpdateClient(t *testing.T) {
},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -988,7 +988,7 @@ func TestUpdateClientTags(t *testing.T) {
svcRes: clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "update client tags with a request that can't be marshalled",
@@ -1004,7 +1004,7 @@ func TestUpdateClientTags(t *testing.T) {
svcRes: clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update client tags with a response that can't be unmarshalled",
@@ -1022,7 +1022,7 @@ func TestUpdateClientTags(t *testing.T) {
},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1125,7 +1125,7 @@ func TestUpdateClientSecret(t *testing.T) {
svcRes: clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "update client with empty new secret",
@@ -1136,7 +1136,7 @@ func TestUpdateClientSecret(t *testing.T) {
svcRes: clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingSecret), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingSecret, http.StatusBadRequest),
},
{
desc: "update client secret with a response that can't be unmarshalled",
@@ -1154,7 +1154,7 @@ func TestUpdateClientSecret(t *testing.T) {
},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1240,7 +1240,7 @@ func TestEnableClient(t *testing.T) {
svcRes: clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "enable client with a response that can't be unmarshalled",
@@ -1257,7 +1257,7 @@ func TestEnableClient(t *testing.T) {
},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1333,7 +1333,7 @@ func TestDisableClient(t *testing.T) {
svcRes: clients.Client{},
svcErr: svcerr.ErrDisableClient,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrDisableClient, http.StatusInternalServerError),
err: errors.NewSDKErrorWithStatus(svcerr.ErrDisableClient, http.StatusUnprocessableEntity),
},
{
desc: "disable client with empty client id",
@@ -1343,7 +1343,7 @@ func TestDisableClient(t *testing.T) {
svcRes: clients.Client{},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "disable client with a response that can't be unmarshalled",
@@ -1360,7 +1360,7 @@ func TestDisableClient(t *testing.T) {
},
svcErr: nil,
response: sdk.Client{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1527,7 +1527,7 @@ func TestSetClientParent(t *testing.T) {
clientID: "",
parentID: parentID,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "set client parent with empty parent id",
@@ -1536,7 +1536,7 @@ func TestSetClientParent(t *testing.T) {
clientID: clientID,
parentID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingParentGroupID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingParentGroupID, http.StatusBadRequest),
},
}
@@ -1622,7 +1622,7 @@ func TestRemoveClientParent(t *testing.T) {
clientID: "",
parentID: parentID,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
@@ -1741,7 +1741,7 @@ func TestCreateClientRole(t *testing.T) {
svcRes: roles.RoleProvision{},
svcErr: nil,
response: sdk.Role{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
},
{
desc: "create client role with empty role name",
@@ -1756,7 +1756,7 @@ func TestCreateClientRole(t *testing.T) {
svcRes: roles.RoleProvision{},
svcErr: nil,
response: sdk.Role{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleName), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleName, http.StatusBadRequest),
},
}
@@ -1887,7 +1887,7 @@ func TestListClientRoles(t *testing.T) {
svcRes: roles.RolePage{},
svcErr: nil,
response: sdk.RolesPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
@@ -1993,7 +1993,7 @@ func TestViewClientRole(t *testing.T) {
svcRes: roles.Role{},
svcErr: nil,
response: sdk.Role{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "view client role with invalid role id",
@@ -2122,7 +2122,7 @@ func TestUpdateClientRole(t *testing.T) {
svcRes: roles.Role{},
svcErr: nil,
response: sdk.Role{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
@@ -2209,7 +2209,7 @@ func TestDeleteClientRole(t *testing.T) {
domainID: domainID,
clientID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "delete client role with invalid role id",
@@ -2319,7 +2319,7 @@ func TestAddClientRoleActions(t *testing.T) {
roleID: roleID,
actions: actions,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "add client role actions with invalid role id",
@@ -2341,7 +2341,7 @@ func TestAddClientRoleActions(t *testing.T) {
actions: []string{},
svcErr: nil,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
},
}
@@ -2433,7 +2433,7 @@ func TestListClientRoleActions(t *testing.T) {
domainID: domainID,
clientID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "list client role actions with invalid role id",
@@ -2451,7 +2451,7 @@ func TestListClientRoleActions(t *testing.T) {
clientID: clientID,
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
@@ -2545,7 +2545,7 @@ func TestRemoveClientRoleActions(t *testing.T) {
clientID: "",
roleID: roleID,
actions: actions,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "remove client role actions with invalid role id",
@@ -2565,7 +2565,7 @@ func TestRemoveClientRoleActions(t *testing.T) {
roleID: roleID,
actions: []string{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
},
}
@@ -2651,7 +2651,7 @@ func TestRemoveAllClientRoleActions(t *testing.T) {
domainID: domainID,
clientID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "remove all client role actions with invalid role id",
@@ -2669,7 +2669,7 @@ func TestRemoveAllClientRoleActions(t *testing.T) {
clientID: clientID,
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
@@ -2770,7 +2770,7 @@ func TestAddClientRoleMembers(t *testing.T) {
roleID: roleID,
members: members,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "add client role members with invalid role id",
@@ -2792,7 +2792,7 @@ func TestAddClientRoleMembers(t *testing.T) {
members: []string{},
svcErr: nil,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
},
}
@@ -2915,7 +2915,7 @@ func TestListClientRoleMembers(t *testing.T) {
},
clientID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "list client role members with invalid role id",
@@ -2941,7 +2941,7 @@ func TestListClientRoleMembers(t *testing.T) {
},
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
@@ -3035,7 +3035,7 @@ func TestRemoveClientRoleMembers(t *testing.T) {
clientID: "",
roleID: roleID,
members: members,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "remove client role members with invalid role id",
@@ -3055,7 +3055,7 @@ func TestRemoveClientRoleMembers(t *testing.T) {
roleID: roleID,
members: []string{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
},
}
@@ -3141,7 +3141,7 @@ func TestRemoveAllClientRoleMembers(t *testing.T) {
domainID: domainID,
clientID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "remove all client role members with invalid role id",
@@ -3159,7 +3159,7 @@ func TestRemoveAllClientRoleMembers(t *testing.T) {
clientID: clientID,
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
+16 -16
View File
@@ -149,7 +149,7 @@ func TestCreateDomain(t *testing.T) {
svcRes: domains.Domain{},
svcErr: nil,
response: sdk.Domain{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "create domain with response that cannot be unmarshalled",
@@ -165,7 +165,7 @@ func TestCreateDomain(t *testing.T) {
},
svcErr: nil,
response: sdk.Domain{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -294,7 +294,7 @@ func TestUpdateDomain(t *testing.T) {
svcRes: domains.Domain{},
svcErr: nil,
response: sdk.Domain{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update domain with response that cannot be unmarshalled",
@@ -313,7 +313,7 @@ func TestUpdateDomain(t *testing.T) {
},
svcErr: nil,
response: sdk.Domain{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -439,7 +439,7 @@ func TestViewDomain(t *testing.T) {
},
svcErr: nil,
response: sdk.Domain{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -568,7 +568,7 @@ func TestListDomians(t *testing.T) {
svcRes: domains.DomainsPage{},
svcErr: nil,
response: sdk.DomainsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "list domains with request that cannot be marshalled",
@@ -592,7 +592,7 @@ func TestListDomians(t *testing.T) {
},
svcErr: nil,
response: sdk.DomainsPage{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -938,7 +938,7 @@ func TestCreateDomainRole(t *testing.T) {
svcRes: roles.RoleProvision{},
svcErr: nil,
response: sdk.Role{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleName), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleName, http.StatusBadRequest),
},
}
@@ -1487,7 +1487,7 @@ func TestAddDomainRoleActions(t *testing.T) {
actions: []string{},
svcErr: nil,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
},
}
@@ -1589,7 +1589,7 @@ func TestListDomainRoleActions(t *testing.T) {
domainID: domainID,
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
@@ -1695,7 +1695,7 @@ func TestRemoveDomainRoleActions(t *testing.T) {
roleID: roleID,
actions: []string{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
},
}
@@ -1791,7 +1791,7 @@ func TestRemoveAllDomainRoleActions(t *testing.T) {
domainID: domainID,
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
@@ -1906,7 +1906,7 @@ func TestAddDomainRoleMembers(t *testing.T) {
members: []string{},
svcErr: nil,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
},
}
@@ -2047,7 +2047,7 @@ func TestListDomainRoleMembers(t *testing.T) {
},
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
@@ -2153,7 +2153,7 @@ func TestRemoveDomainRoleMembers(t *testing.T) {
roleID: roleID,
members: []string{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
},
}
@@ -2249,7 +2249,7 @@ func TestRemoveAllDomainRoleMembers(t *testing.T) {
domainID: domainID,
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
+52 -52
View File
@@ -211,7 +211,7 @@ func TestCreateGroup(t *testing.T) {
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
},
{
desc: "create group with name that is too long",
@@ -226,7 +226,7 @@ func TestCreateGroup(t *testing.T) {
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
},
{
desc: "create group with request that cannot be marshalled",
@@ -243,7 +243,7 @@ func TestCreateGroup(t *testing.T) {
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "create group with service response that cannot be unmarshalled",
@@ -262,7 +262,7 @@ func TestCreateGroup(t *testing.T) {
svcRes: uGroup,
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -425,7 +425,7 @@ func TestListGroups(t *testing.T) {
svcRes: groups.Page{},
svcErr: nil,
response: sdk.GroupsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
},
{
desc: "list groups with given name",
@@ -480,7 +480,7 @@ func TestListGroups(t *testing.T) {
svcRes: groups.Page{},
svcErr: nil,
response: sdk.GroupsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "list groups with service response that cannot be unmarshalled",
@@ -513,7 +513,7 @@ func TestListGroups(t *testing.T) {
},
svcErr: nil,
response: sdk.GroupsPage{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -615,9 +615,9 @@ func TestViewGroup(t *testing.T) {
withRoles: false,
groupID: wrongID,
svcRes: groups.Group{},
svcErr: svcerr.ErrViewEntity,
svcErr: svcerr.ErrNotFound,
response: sdk.Group{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
{
desc: "view group with service response that cannot be unmarshalled",
@@ -634,7 +634,7 @@ func TestViewGroup(t *testing.T) {
},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
{
desc: "view group with empty id",
@@ -822,7 +822,7 @@ func TestUpdateGroup(t *testing.T) {
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update group with service response that cannot be unmarshalled",
@@ -849,7 +849,7 @@ func TestUpdateGroup(t *testing.T) {
},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
@@ -984,7 +984,7 @@ func TestUpdateGroupTags(t *testing.T) {
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update group tags with a response that can't be unmarshalled",
@@ -1001,7 +1001,7 @@ func TestUpdateGroupTags(t *testing.T) {
},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1096,7 +1096,7 @@ func TestEnableGroup(t *testing.T) {
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "enable group with service response that cannot be unmarshalled",
@@ -1112,7 +1112,7 @@ func TestEnableGroup(t *testing.T) {
},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1207,7 +1207,7 @@ func TestDisableGroup(t *testing.T) {
svcRes: groups.Group{},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "disable group with service response that cannot be unmarshalled",
@@ -1223,7 +1223,7 @@ func TestDisableGroup(t *testing.T) {
},
svcErr: nil,
response: sdk.Group{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1389,7 +1389,7 @@ func TestSetGroupParent(t *testing.T) {
groupID: "",
parentID: parentID,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "set group parent with empty parent id",
@@ -1398,7 +1398,7 @@ func TestSetGroupParent(t *testing.T) {
groupID: groupID,
parentID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
},
}
@@ -1485,7 +1485,7 @@ func TestRemoveGroupParent(t *testing.T) {
groupID: "",
parentID: parentID,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
@@ -1572,7 +1572,7 @@ func TestAddChildrenGroups(t *testing.T) {
groupID: "",
childrenIDs: []string{childID},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "add children group with empty children ids",
@@ -1581,7 +1581,7 @@ func TestAddChildrenGroups(t *testing.T) {
groupID: groupID,
childrenIDs: []string{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingChildrenGroupIDs), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingChildrenGroupIDs, http.StatusBadRequest),
},
}
@@ -1668,7 +1668,7 @@ func TestRemoveChildrenGroups(t *testing.T) {
groupID: "",
childrenIDs: []string{childID},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "remove children group with empty children ids",
@@ -1677,7 +1677,7 @@ func TestRemoveChildrenGroups(t *testing.T) {
groupID: groupID,
childrenIDs: []string{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingChildrenGroupIDs), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingChildrenGroupIDs, http.StatusBadRequest),
},
}
@@ -1757,7 +1757,7 @@ func TestRemoveAllChildrenGroups(t *testing.T) {
token: validToken,
groupID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
@@ -1925,7 +1925,7 @@ func TestListChildrenGroups(t *testing.T) {
svcRes: groups.Page{},
svcErr: nil,
response: sdk.GroupsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
},
{
desc: "list children groups with given metadata",
@@ -1978,7 +1978,7 @@ func TestListChildrenGroups(t *testing.T) {
svcRes: groups.Page{},
svcErr: nil,
response: sdk.GroupsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "list children groups with service response that cannot be unmarshalled",
@@ -2010,7 +2010,7 @@ func TestListChildrenGroups(t *testing.T) {
},
svcErr: nil,
response: sdk.GroupsPage{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
@@ -2186,7 +2186,7 @@ func TestHierarchy(t *testing.T) {
},
svcErr: nil,
response: sdk.GroupsHierarchyPage{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
@@ -2307,7 +2307,7 @@ func TestCreateGroupRole(t *testing.T) {
svcRes: roles.RoleProvision{},
svcErr: nil,
response: sdk.Role{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidIDFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidIDFormat, http.StatusBadRequest),
},
{
desc: "create group role with empty role name",
@@ -2322,7 +2322,7 @@ func TestCreateGroupRole(t *testing.T) {
svcRes: roles.RoleProvision{},
svcErr: nil,
response: sdk.Role{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleName), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleName, http.StatusBadRequest),
},
}
@@ -2454,7 +2454,7 @@ func TestListGroupRoles(t *testing.T) {
svcRes: roles.RolePage{},
svcErr: nil,
response: sdk.RolesPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
@@ -2560,7 +2560,7 @@ func TestViewGroupRole(t *testing.T) {
svcRes: roles.Role{},
svcErr: nil,
response: sdk.Role{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "view group role with invalid role id",
@@ -2690,7 +2690,7 @@ func TestUpdateGroupRole(t *testing.T) {
svcRes: roles.Role{},
svcErr: nil,
response: sdk.Role{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
}
@@ -2778,7 +2778,7 @@ func TestDeleteGroupRole(t *testing.T) {
domainID: domainID,
groupID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "delete group role with invalid role id",
@@ -2889,7 +2889,7 @@ func TestAddGroupRoleActions(t *testing.T) {
roleID: roleID,
actions: actions,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "add group role actions with invalid role id",
@@ -2911,7 +2911,7 @@ func TestAddGroupRoleActions(t *testing.T) {
actions: []string{},
svcErr: nil,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
},
}
@@ -3004,7 +3004,7 @@ func TestListGroupRoleActions(t *testing.T) {
domainID: domainID,
groupID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "list group role actions with invalid role id",
@@ -3022,7 +3022,7 @@ func TestListGroupRoleActions(t *testing.T) {
groupID: groupID,
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
@@ -3117,7 +3117,7 @@ func TestRemoveGroupRoleActions(t *testing.T) {
groupID: "",
roleID: roleID,
actions: actions,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "remove group role actions with invalid role id",
@@ -3137,7 +3137,7 @@ func TestRemoveGroupRoleActions(t *testing.T) {
roleID: roleID,
actions: []string{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPolicyEntityType), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPolicyEntityType, http.StatusBadRequest),
},
}
@@ -3224,7 +3224,7 @@ func TestRemoveAllGroupRoleActions(t *testing.T) {
domainID: domainID,
groupID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "remove all group role actions with invalid role id",
@@ -3242,7 +3242,7 @@ func TestRemoveAllGroupRoleActions(t *testing.T) {
groupID: groupID,
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
@@ -3344,7 +3344,7 @@ func TestAddGroupRoleMembers(t *testing.T) {
roleID: roleID,
members: members,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "add group role members with invalid role id",
@@ -3366,7 +3366,7 @@ func TestAddGroupRoleMembers(t *testing.T) {
members: []string{},
svcErr: nil,
response: []string{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
},
}
@@ -3490,7 +3490,7 @@ func TestListGroupRoleMembers(t *testing.T) {
},
groupID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "list group role members with invalid role id",
@@ -3516,7 +3516,7 @@ func TestListGroupRoleMembers(t *testing.T) {
},
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
@@ -3611,7 +3611,7 @@ func TestRemoveGroupRoleMembers(t *testing.T) {
groupID: "",
roleID: roleID,
members: members,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "remove group role members with invalid role id",
@@ -3631,7 +3631,7 @@ func TestRemoveGroupRoleMembers(t *testing.T) {
roleID: roleID,
members: []string{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleMembers), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleMembers, http.StatusBadRequest),
},
}
@@ -3718,7 +3718,7 @@ func TestRemoveAllGroupRoleMembers(t *testing.T) {
domainID: domainID,
groupID: "",
roleID: roleID,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "remove all group role members with invalid role id",
@@ -3736,7 +3736,7 @@ func TestRemoveAllGroupRoleMembers(t *testing.T) {
groupID: groupID,
roleID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingRoleID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingRoleID, http.StatusBadRequest),
},
}
+3 -3
View File
@@ -85,7 +85,7 @@ func TestSendInvitation(t *testing.T) {
},
svcReq: domains.Invitation{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "send invitation with empty role ID",
@@ -97,7 +97,7 @@ func TestSendInvitation(t *testing.T) {
},
svcReq: domains.Invitation{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "send inviation with invalid domainID",
@@ -218,7 +218,7 @@ func TestListInvitation(t *testing.T) {
svcRes: domains.InvitationPage{},
svcErr: nil,
response: sdk.InvitationPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
},
}
for _, tc := range cases {
+4 -4
View File
@@ -225,7 +225,7 @@ func TestRetrieveJournal(t *testing.T) {
svcRes: journal.JournalsPage{},
svcErr: nil,
response: sdk.JournalsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidEntityType), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidEntityType, http.StatusBadRequest),
},
{
desc: "retrieve journal with empty entity ID",
@@ -273,7 +273,7 @@ func TestRetrieveJournal(t *testing.T) {
svcRes: journal.JournalsPage{},
svcErr: nil,
response: sdk.JournalsPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
},
{
desc: "retrieve journal with invalid page metadata",
@@ -292,7 +292,7 @@ func TestRetrieveJournal(t *testing.T) {
svcRes: journal.JournalsPage{},
svcErr: nil,
response: sdk.JournalsPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "retrieve journal with response that cannot be unmarshalled",
@@ -325,7 +325,7 @@ func TestRetrieveJournal(t *testing.T) {
},
svcErr: nil,
response: sdk.JournalsPage{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
+1 -1
View File
@@ -153,7 +153,7 @@ func TestSendMessage(t *testing.T) {
authRes: &grpcClientsV1.AuthnRes{Authenticated: true, Id: ""},
authErr: nil,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrEmptyMessage), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrEmptyMessage, http.StatusBadRequest),
},
{
desc: "publish message with channel subtopic",
+2 -2
View File
@@ -85,7 +85,7 @@ func TestIssueToken(t *testing.T) {
svcRes: &grpcTokenV1.Token{},
svcErr: nil,
response: sdk.Token{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingUsernameEmail), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingUsernameEmail, http.StatusBadRequest),
},
{
desc: "issue token with empty secret",
@@ -96,7 +96,7 @@ func TestIssueToken(t *testing.T) {
svcRes: &grpcTokenV1.Token{},
svcErr: nil,
response: sdk.Token{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
},
}
for _, tc := range cases {
+45 -45
View File
@@ -136,7 +136,7 @@ func TestCreateUser(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingUsername), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingUsername, http.StatusBadRequest),
},
{
desc: "register user with first name too long",
@@ -151,7 +151,7 @@ func TestCreateUser(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrNameSize, http.StatusBadRequest),
},
{
desc: "register user with empty userName",
@@ -171,7 +171,7 @@ func TestCreateUser(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingUsername), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingUsername, http.StatusBadRequest),
},
{
desc: "register user with empty secret",
@@ -191,7 +191,7 @@ func TestCreateUser(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
},
{
desc: "register user with secret that is too short",
@@ -211,7 +211,7 @@ func TestCreateUser(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrPasswordFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrPasswordFormat, http.StatusBadRequest),
},
{
desc: "register a user with request that can't be marshalled",
@@ -232,7 +232,7 @@ func TestCreateUser(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "register a user with response that can't be unmarshalled",
@@ -254,7 +254,7 @@ func TestCreateUser(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -412,7 +412,7 @@ func TestListUsers(t *testing.T) {
svcRes: users.UsersPage{},
svcErr: nil,
response: sdk.UsersPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
},
{
desc: "list users with given metadata",
@@ -523,7 +523,7 @@ func TestListUsers(t *testing.T) {
svcRes: users.UsersPage{},
svcErr: nil,
response: sdk.UsersPage{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "list users with response that can't be unmarshalled",
@@ -553,7 +553,7 @@ func TestListUsers(t *testing.T) {
},
},
response: sdk.UsersPage{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
@@ -671,7 +671,7 @@ func TestSearchUsers(t *testing.T) {
Limit: limit,
FirstName: "",
},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrEmptySearchQuery), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrEmptySearchQuery, http.StatusBadRequest),
},
{
desc: "search for users with invalid length of query",
@@ -681,7 +681,7 @@ func TestSearchUsers(t *testing.T) {
Limit: limit,
Username: "a",
},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrLenSearchQuery, apiutil.ErrValidation), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrValidation, http.StatusBadRequest),
},
{
desc: "search for users with invalid limit",
@@ -691,7 +691,7 @@ func TestSearchUsers(t *testing.T) {
Limit: 0,
Username: "user_10",
},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrLimitSize, http.StatusBadRequest),
},
}
@@ -761,9 +761,9 @@ func TestViewUser(t *testing.T) {
token: validToken,
userID: wrongID,
svcRes: users.User{},
svcErr: svcerr.ErrViewEntity,
svcErr: svcerr.ErrNotFound,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
{
desc: "view user with empty id",
@@ -788,7 +788,7 @@ func TestViewUser(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -867,7 +867,7 @@ func TestUserProfile(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1009,7 +1009,7 @@ func TestUpdateUser(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update user with response that can't be unmarshalled",
@@ -1031,7 +1031,7 @@ func TestUpdateUser(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
@@ -1155,7 +1155,7 @@ func TestUpdateUserTags(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "update user tags with request that can't be marshalled",
@@ -1170,7 +1170,7 @@ func TestUpdateUserTags(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update user tags with response that can't be unmarshalled",
@@ -1192,7 +1192,7 @@ func TestUpdateUserTags(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1318,7 +1318,7 @@ func TestUpdateUserEmail(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "update email with response that can't be unmarshalled",
@@ -1340,7 +1340,7 @@ func TestUpdateUserEmail(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1398,9 +1398,9 @@ func TestResetPasswordRequest(t *testing.T) {
desc: "reset password request with invalid email",
email: "invalidemail",
svcRes: users.User{},
svcErr: svcerr.ErrViewEntity,
svcErr: svcerr.ErrNotFound,
issueRes: &grpcTokenV1.Token{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrViewEntity, http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
{
desc: "reset password request with empty email",
@@ -1408,7 +1408,7 @@ func TestResetPasswordRequest(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
issueRes: &grpcTokenV1.Token{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingEmail), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingEmail, http.StatusBadRequest),
},
}
for _, tc := range cases {
@@ -1478,7 +1478,7 @@ func TestResetPassword(t *testing.T) {
newPassword: "",
confPassword: newPassword,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
},
{
desc: "reset password with empty confirm password",
@@ -1487,7 +1487,7 @@ func TestResetPassword(t *testing.T) {
newPassword: newPassword,
confPassword: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingConfPass), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingConfPass, http.StatusBadRequest),
},
{
desc: "reset password with new password not matching confirm password",
@@ -1496,7 +1496,7 @@ func TestResetPassword(t *testing.T) {
newPassword: newPassword,
confPassword: "wrongPassword",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidResetPass), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidResetPass, http.StatusBadRequest),
},
{
desc: "reset password with weak password",
@@ -1505,7 +1505,7 @@ func TestResetPassword(t *testing.T) {
newPassword: "weak",
confPassword: "weak",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrPasswordFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrPasswordFormat, http.StatusBadRequest),
},
}
for _, tc := range cases {
@@ -1590,7 +1590,7 @@ func TestUpdatePassword(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
},
{
desc: "update password with empty new password",
@@ -1600,7 +1600,7 @@ func TestUpdatePassword(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingPass), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingPass, http.StatusBadRequest),
},
{
desc: "update password with invalid new password",
@@ -1610,7 +1610,7 @@ func TestUpdatePassword(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrPasswordFormat), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrPasswordFormat, http.StatusBadRequest),
},
{
desc: "update password with invalid old password",
@@ -1636,7 +1636,7 @@ func TestUpdatePassword(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1757,7 +1757,7 @@ func TestUpdateUserRole(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "update user role with request that can't be marshalled",
@@ -1772,7 +1772,7 @@ func TestUpdateUserRole(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update user role with response that can't be unmarshalled",
@@ -1794,7 +1794,7 @@ func TestUpdateUserRole(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -1930,7 +1930,7 @@ func TestUpdateUsername(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "update username with response that can't be unmarshalled",
@@ -1958,7 +1958,7 @@ func TestUpdateUsername(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -2085,7 +2085,7 @@ func TestUpdateProfilePicture(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "update profile picture with request that can't be marshalled",
@@ -2100,7 +2100,7 @@ func TestUpdateProfilePicture(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("json: unsupported type: chan int")),
err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")),
},
{
desc: "update profile picture with response that can't be unmarshalled",
@@ -2121,7 +2121,7 @@ func TestUpdateProfilePicture(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
@@ -2284,7 +2284,7 @@ func TestDisableUser(t *testing.T) {
svcRes: users.User{},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest),
},
{
desc: "disable user with response that can't be unmarshalled",
@@ -2299,7 +2299,7 @@ func TestDisableUser(t *testing.T) {
},
svcErr: nil,
response: sdk.User{},
err: errors.NewSDKError(errors.New("unexpected end of JSON input")),
err: errors.NewSDKError(fmt.Errorf("unexpected end of JSON input")),
},
}
for _, tc := range cases {
+139 -175
View File
@@ -135,30 +135,8 @@ func TestRegister(t *testing.T) {
user: user,
token: validToken,
contentType: contentType,
status: http.StatusConflict,
err: svcerr.ErrConflict,
},
{
desc: "register a new user with an empty token",
user: user,
token: "",
contentType: contentType,
status: http.StatusUnauthorized,
err: apiutil.ErrBearerToken,
},
{
desc: "register a user with an invalid ID",
user: users.User{
ID: inValid,
Email: "user@example.com",
Credentials: users.Credentials{
Secret: "12345678",
},
},
token: validToken,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: svcerr.ErrConflict,
},
{
desc: "register a user that can't be marshalled",
@@ -174,7 +152,7 @@ func TestRegister(t *testing.T) {
token: validToken,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "register user with invalid status",
@@ -206,7 +184,7 @@ func TestRegister(t *testing.T) {
token: validToken,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrNameSize,
},
{
desc: "register user with invalid content type",
@@ -214,7 +192,7 @@ func TestRegister(t *testing.T) {
token: validToken,
contentType: "application/xml",
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "register user with empty request body",
@@ -222,7 +200,7 @@ func TestRegister(t *testing.T) {
token: validToken,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMissingFirstName,
},
{
desc: "register user with invalid username",
@@ -239,7 +217,7 @@ func TestRegister(t *testing.T) {
token: validToken,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidUsername,
},
}
@@ -323,7 +301,7 @@ func TestView(t *testing.T) {
desc: "view user with invalid ID",
token: validToken,
id: inValid,
status: http.StatusBadRequest,
status: http.StatusUnprocessableEntity,
authnRes: verifiedSession,
svcErr: svcerr.ErrViewEntity,
err: svcerr.ErrViewEntity,
@@ -401,7 +379,7 @@ func TestViewProfile(t *testing.T) {
desc: "view profile with service error",
token: validToken,
id: user.ID,
status: http.StatusBadRequest,
status: http.StatusUnprocessableEntity,
authnRes: verifiedSession,
svcErr: svcerr.ErrViewEntity,
err: svcerr.ErrViewEntity,
@@ -499,7 +477,7 @@ func TestListUsers(t *testing.T) {
query: "offset=invalid",
status: http.StatusBadRequest,
authnRes: verifiedSession,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list users with limit",
@@ -522,7 +500,7 @@ func TestListUsers(t *testing.T) {
query: "limit=invalid",
status: http.StatusBadRequest,
authnRes: verifiedSession,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list users with limit greater than max",
@@ -530,7 +508,7 @@ func TestListUsers(t *testing.T) {
query: fmt.Sprintf("limit=%d", api.MaxLimitSize+1),
status: http.StatusBadRequest,
authnRes: verifiedSession,
err: apiutil.ErrValidation,
err: apiutil.ErrLimitSize,
},
{
desc: "list users with name",
@@ -574,7 +552,7 @@ func TestListUsers(t *testing.T) {
query: "status=invalid",
status: http.StatusBadRequest,
authnRes: verifiedSession,
err: apiutil.ErrValidation,
err: svcerr.ErrInvalidStatus,
},
{
desc: "list users with duplicate status",
@@ -626,7 +604,7 @@ func TestListUsers(t *testing.T) {
query: "metadata=invalid",
status: http.StatusBadRequest,
authnRes: verifiedSession,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list users with duplicate metadata",
@@ -658,20 +636,6 @@ func TestListUsers(t *testing.T) {
authnRes: verifiedSession,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list users with list perms",
token: validToken,
listUsersResponse: users.UsersPage{
Page: users.Page{
Total: 1,
},
Users: []users.User{user},
},
query: "list_perms=true",
status: http.StatusOK,
authnRes: verifiedSession,
err: nil,
},
{
desc: "list users with duplicate list perms",
token: validToken,
@@ -702,14 +666,6 @@ func TestListUsers(t *testing.T) {
authnRes: verifiedSession,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list users with duplicate list perms",
token: validToken,
query: "list_perms=true&list_perms=true",
status: http.StatusBadRequest,
authnRes: verifiedSession,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list users with email",
token: validToken,
@@ -764,7 +720,7 @@ func TestListUsers(t *testing.T) {
query: "dir=invalid",
status: http.StatusBadRequest,
authnRes: verifiedSession,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidDirection,
},
{
desc: "list users with duplicate order direction",
@@ -908,7 +864,7 @@ func TestSearchUsers(t *testing.T) {
desc: "serach users with service error",
token: validToken,
query: "username=username",
status: http.StatusBadRequest,
status: http.StatusUnprocessableEntity,
svcErr: svcerr.ErrViewEntity,
err: svcerr.ErrViewEntity,
},
@@ -1028,7 +984,7 @@ func TestUpdate(t *testing.T) {
authnRes: verifiedSession,
contentType: "application/xml",
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update user with malformed data",
@@ -1038,7 +994,7 @@ func TestUpdate(t *testing.T) {
authnRes: verifiedSession,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "update user with empty id",
@@ -1047,8 +1003,8 @@ func TestUpdate(t *testing.T) {
token: validToken,
authnRes: verifiedSession,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
status: http.StatusUnprocessableEntity,
err: svcerr.ErrViewEntity,
},
}
@@ -1167,7 +1123,7 @@ func TestUpdateTags(t *testing.T) {
token: validToken,
authnRes: verifiedSession,
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update user tags with empty id",
@@ -1177,7 +1133,7 @@ func TestUpdateTags(t *testing.T) {
token: validToken,
authnRes: verifiedSession,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMissingID,
},
{
desc: "update user with malfomed data",
@@ -1187,7 +1143,7 @@ func TestUpdateTags(t *testing.T) {
token: validToken,
authnRes: verifiedSession,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
}
@@ -1334,7 +1290,7 @@ func TestUpdateEmail(t *testing.T) {
contentType: "application/xml",
token: validToken,
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update user email with malformed data",
@@ -1349,7 +1305,7 @@ func TestUpdateEmail(t *testing.T) {
token: validToken,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "update user email with service error",
@@ -1368,29 +1324,31 @@ func TestUpdateEmail(t *testing.T) {
}
for _, tc := range cases {
req := testRequest{
user: us.Client(),
method: http.MethodPatch,
url: fmt.Sprintf("%s/users/%s/email", us.URL, tc.user.ID),
contentType: tc.contentType,
token: tc.token,
body: strings.NewReader(tc.data),
}
t.Run(tc.desc, func(t *testing.T) {
req := testRequest{
user: us.Client(),
method: http.MethodPatch,
url: fmt.Sprintf("%s/users/%s/email", us.URL, tc.user.ID),
contentType: tc.contentType,
token: tc.token,
body: strings.NewReader(tc.data),
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("UpdateEmail", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.user, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("UpdateEmail", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.user, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
})
}
}
@@ -1486,7 +1444,7 @@ func TestUpdateUsername(t *testing.T) {
contentType: "application/xml",
token: validToken,
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update user email with malformed data",
@@ -1501,7 +1459,7 @@ func TestUpdateUsername(t *testing.T) {
token: validToken,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "update username with invalid username",
@@ -1521,29 +1479,31 @@ func TestUpdateUsername(t *testing.T) {
}
for _, tc := range cases {
req := testRequest{
user: us.Client(),
method: http.MethodPatch,
url: fmt.Sprintf("%s/users/%s/username", us.URL, tc.user.ID),
contentType: tc.contentType,
token: tc.token,
body: strings.NewReader(tc.data),
}
t.Run(tc.desc, func(t *testing.T) {
req := testRequest{
user: us.Client(),
method: http.MethodPatch,
url: fmt.Sprintf("%s/users/%s/username", us.URL, tc.user.ID),
contentType: tc.contentType,
token: tc.token,
body: strings.NewReader(tc.data),
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("UpdateUsername", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.user, tc.err)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("UpdateUsername", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.user, tc.err)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
})
}
}
@@ -1624,7 +1584,7 @@ func TestUpdateProfilePicture(t *testing.T) {
contentType: "application/xml",
token: validToken,
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update profile picture with malformed data",
@@ -1634,7 +1594,7 @@ func TestUpdateProfilePicture(t *testing.T) {
token: validToken,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "update profile picture with failed to update",
@@ -1652,29 +1612,31 @@ func TestUpdateProfilePicture(t *testing.T) {
}
for _, tc := range cases {
req := testRequest{
user: us.Client(),
method: http.MethodPatch,
url: fmt.Sprintf("%s/users/%s/picture", us.URL, tc.user.ID),
contentType: tc.contentType,
token: tc.token,
body: strings.NewReader(tc.data),
}
t.Run(tc.desc, func(t *testing.T) {
req := testRequest{
user: us.Client(),
method: http.MethodPatch,
url: fmt.Sprintf("%s/users/%s/picture", us.URL, tc.user.ID),
contentType: tc.contentType,
token: tc.token,
body: strings.NewReader(tc.data),
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("UpdateProfilePicture", mock.Anything, tc.authnRes, tc.user.ID, mock.Anything).Return(tc.user, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("UpdateProfilePicture", mock.Anything, tc.authnRes, tc.user.ID, mock.Anything).Return(tc.user, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
})
}
}
@@ -1868,14 +1830,14 @@ func TestVerifyEmail(t *testing.T) {
desc: "verify email with empty token",
token: "",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrInvalidVerification,
},
{
desc: "verify email with service error",
token: validToken,
status: http.StatusBadRequest,
svcErr: svcerr.ErrMalformedEntity,
err: svcerr.ErrMalformedEntity,
status: http.StatusUnprocessableEntity,
svcErr: svcerr.ErrUpdateEntity,
err: svcerr.ErrUpdateEntity,
},
}
@@ -2101,7 +2063,7 @@ func TestUpdateRole(t *testing.T) {
authnRes: verifiedSession,
contentType: "application/xml",
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update user with malformed data",
@@ -2111,7 +2073,7 @@ func TestUpdateRole(t *testing.T) {
authnRes: verifiedSession,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "update user with service error",
@@ -2251,7 +2213,7 @@ func TestUpdateSecret(t *testing.T) {
token: validToken,
authnRes: verifiedSession,
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update user secret with malformed data",
@@ -2267,7 +2229,7 @@ func TestUpdateSecret(t *testing.T) {
token: validToken,
authnRes: verifiedSession,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
}
@@ -2326,14 +2288,14 @@ func TestIssueToken(t *testing.T) {
data: fmt.Sprintf(dataFormat, "", secret),
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMissingUsernameEmail,
},
{
desc: "issue token with empty secret",
data: fmt.Sprintf(dataFormat, validUsername, ""),
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMissingPass,
},
{
desc: "issue token with invalid email",
@@ -2347,14 +2309,14 @@ func TestIssueToken(t *testing.T) {
data: fmt.Sprintf(dataFormat, validUsername, secret),
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "issue token with invalid contentype",
data: fmt.Sprintf(dataFormat, "invalid", secret),
contentType: "application/xml",
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
}
@@ -2443,7 +2405,7 @@ func TestRefreshToken(t *testing.T) {
token: validToken,
authnRes: verifiedSession,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "refresh token with invalid contentype",
@@ -2452,35 +2414,37 @@ func TestRefreshToken(t *testing.T) {
token: validToken,
authnRes: verifiedSession,
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrValidation,
err: apiutil.ErrUnsupportedContentType,
},
}
for _, tc := range cases {
req := testRequest{
user: us.Client(),
method: http.MethodPost,
url: fmt.Sprintf("%s/users/tokens/refresh", us.URL),
contentType: tc.contentType,
body: strings.NewReader(tc.data),
token: tc.token,
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("RefreshToken", mock.Anything, tc.authnRes, tc.token, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
if tc.err != nil {
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
t.Run(tc.desc, func(t *testing.T) {
req := testRequest{
user: us.Client(),
method: http.MethodPost,
url: fmt.Sprintf("%s/users/tokens/refresh", us.URL),
contentType: tc.contentType,
body: strings.NewReader(tc.data),
token: tc.token,
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("RefreshToken", mock.Anything, tc.authnRes, tc.token, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
if tc.err != nil {
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
})
}
}
-1
View File
@@ -17,7 +17,6 @@ import (
const (
valid = "valid"
invalid = "invalid"
secret = "QJg58*aMan7j"
name = "user"
validEmail = "example@domain.com"
+11 -11
View File
@@ -384,7 +384,7 @@ func decodeUpdateUser(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -399,7 +399,7 @@ func decodeUpdateUserTags(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -414,7 +414,7 @@ func decodeUpdateUserEmail(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -427,7 +427,7 @@ func decodeUpdateUserSecret(_ context.Context, r *http.Request) (any, error) {
req := updateUserSecretReq{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -442,7 +442,7 @@ func decodeUpdateUsername(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -458,7 +458,7 @@ func decodeUpdateUserProfilePicture(_ context.Context, r *http.Request) (any, er
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -471,7 +471,7 @@ func decodePasswordResetRequest(_ context.Context, r *http.Request) (any, error)
var req passResetReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -484,7 +484,7 @@ func decodePasswordReset(_ context.Context, r *http.Request) (any, error) {
var req resetTokenReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -499,7 +499,7 @@ func decodeUpdateUserRole(_ context.Context, r *http.Request) (any, error) {
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
var err error
req.role, err = users.ToRole(req.Role)
@@ -513,7 +513,7 @@ func decodeCredentials(_ context.Context, r *http.Request) (any, error) {
req := loginUserReq{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
@@ -535,7 +535,7 @@ func decodeCreateUserReq(_ context.Context, r *http.Request) (any, error) {
var req createUserReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
+2 -2
View File
@@ -12,8 +12,8 @@ import (
const cost int = 10
var (
errHashPassword = errors.New("generate hash from password failed")
errComparePassword = errors.New("compare hash and password failed")
errHashPassword = errors.NewServiceError("generate hash from password failed")
errComparePassword = errors.NewServiceError("compare hash and password failed")
)
var _ users.Hasher = (*bcryptHasher)(nil)
+26
View File
@@ -0,0 +1,26 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import "github.com/absmach/supermq/pkg/errors"
var _ errors.Mapper = (*duplicateErrors)(nil)
type duplicateErrors struct{}
// GetError maps constraint names to known errors.
func (d duplicateErrors) GetError(constraint string) (error, bool) {
switch constraint {
case "clients_email_key":
return errors.NewRequestError("email id already registered"), true
case "clients_username_key":
return errors.NewRequestError("username not available"), true
default:
return nil, false
}
}
func NewDuplicateErrors() errors.Mapper {
return duplicateErrors{}
}
+39 -49
View File
@@ -18,18 +18,20 @@ import (
"github.com/absmach/supermq/pkg/postgres"
"github.com/absmach/supermq/users"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v5/pgconn"
)
var pgDuplicateErrCode = "23505"
type userRepo struct {
Repository users.UserRepository
eh errors.Handler
}
func NewRepository(db postgres.Database) users.Repository {
errHandlerOptions := []errors.HandlerOption{
postgres.WithDuplicateErrors(NewDuplicateErrors()),
}
return &userRepo{
Repository: users.UserRepository{DB: db},
eh: postgres.NewErrorHandler(errHandlerOptions...),
}
}
@@ -40,12 +42,12 @@ func (repo *userRepo) Save(ctx context.Context, c users.User) (users.User, error
dbu, err := toDBUser(c)
if err != nil {
return users.User{}, errors.Wrap(repoerr.ErrCreateEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
}
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
if err != nil {
return users.User{}, handleSaveError(repoerr.ErrCreateEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err)
}
defer row.Close()
@@ -54,40 +56,28 @@ func (repo *userRepo) Save(ctx context.Context, c users.User) (users.User, error
dbu = DBUser{}
if err := row.StructScan(&dbu); err != nil {
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrFailedOpDB, err)
}
user, err := ToUser(dbu)
if err != nil {
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
}
return user, nil
}
func handleSaveError(wrapper, err error) error {
if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgDuplicateErrCode {
switch pqErr.ConstraintName {
case "clients_email_key":
return errors.ErrEmailAlreadyExists
case "clients_username_key":
return errors.ErrUsernameNotAvailable
}
}
return postgres.HandleError(wrapper, err)
}
func (repo *userRepo) CheckSuperAdmin(ctx context.Context, adminID string) error {
q := "SELECT 1 FROM users WHERE id = $1 AND role = $2"
rows, err := repo.Repository.DB.QueryContext(ctx, q, adminID, users.AdminRole)
if err != nil {
return postgres.HandleError(repoerr.ErrViewEntity, err)
return repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
if rows.Next() {
if err := rows.Err(); err != nil {
return postgres.HandleError(repoerr.ErrViewEntity, err)
return repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return nil
}
@@ -105,7 +95,7 @@ func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User,
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
if err != nil {
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
@@ -115,12 +105,12 @@ func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User,
}
if err = rows.StructScan(&dbu); err != nil {
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
user, err := ToUser(dbu)
if err != nil {
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
}
return user, nil
@@ -129,7 +119,7 @@ func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User,
func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.UsersPage, error) {
query, err := PageQuery(pm)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
}
squery := applyOrdering(query, pm)
@@ -140,26 +130,26 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
dbPage, err := ToDBUsersPage(pm)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
}
var items []users.User
if !pm.OnlyTotal {
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrFailedToRetrieveAllGroups, err)
}
defer rows.Close()
for rows.Next() {
dbu := DBUser{}
if err := rows.StructScan(&dbu); err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := ToUser(dbu)
if err != nil {
return users.UsersPage{}, err
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
}
items = append(items, c)
@@ -170,7 +160,7 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
page := users.UsersPage{
@@ -242,12 +232,12 @@ func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) (
func (repo *userRepo) update(ctx context.Context, user users.User, query string) (users.User, error) {
dbu, err := toDBUser(user)
if err != nil {
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
}
row, err := repo.Repository.DB.NamedQueryContext(ctx, query, dbu)
if err != nil {
return users.User{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
@@ -257,7 +247,7 @@ func (repo *userRepo) update(ctx context.Context, user users.User, query string)
}
if err := row.StructScan(&dbu); err != nil {
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
}
return ToUser(dbu)
@@ -308,7 +298,7 @@ func (repo *userRepo) Delete(ctx context.Context, id string) error {
result, err := repo.Repository.DB.ExecContext(ctx, q, id)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
return repo.eh.HandleError(repoerr.ErrRemoveEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
@@ -320,7 +310,7 @@ func (repo *userRepo) Delete(ctx context.Context, id string) error {
func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.UsersPage, error) {
query, err := PageQuery(pm)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
}
tq := query
@@ -330,12 +320,12 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
dbPage, err := ToDBUsersPage(pm)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
}
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
@@ -343,7 +333,7 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
for rows.Next() {
dbu := DBUser{}
if err := rows.StructScan(&dbu); err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := ToUser(dbu)
@@ -358,7 +348,7 @@ func (repo *userRepo) SearchUsers(ctx context.Context, pm users.Page) (users.Use
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
page := users.UsersPage{
@@ -381,7 +371,7 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
}
query, err := PageQuery(pm)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrParseQueryParams, err)
}
squery := applyOrdering(query, pm)
@@ -389,11 +379,11 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
u.created_at, u.updated_at, COALESCE(u.updated_by, '') AS updated_by FROM users u %s LIMIT :limit OFFSET :offset;`, squery)
dbPage, err := ToDBUsersPage(pm)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrMarshalBDEntity, err)
}
rows, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
@@ -401,12 +391,12 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
for rows.Next() {
dbu := DBUser{}
if err := rows.StructScan(&dbu); err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
c, err := ToUser(dbu)
if err != nil {
return users.UsersPage{}, err
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrUnmarshalBDEntity, err)
}
items = append(items, c)
@@ -415,7 +405,7 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
total, err := postgres.Total(ctx, repo.Repository.DB, cq, dbPage)
if err != nil {
return users.UsersPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.UsersPage{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
page := users.UsersPage{
@@ -441,14 +431,14 @@ func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users.
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
if err != nil {
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
dbu = DBUser{}
if row.Next() {
if err := row.StructScan(&dbu); err != nil {
return users.User{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return ToUser(dbu)
@@ -468,14 +458,14 @@ func (repo *userRepo) RetrieveByUsername(ctx context.Context, username string) (
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
if err != nil {
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
defer row.Close()
dbu = DBUser{}
if row.Next() {
if err := row.StructScan(&dbu); err != nil {
return users.User{}, errors.Wrap(repoerr.ErrViewEntity, err)
return users.User{}, repo.eh.HandleError(repoerr.ErrViewEntity, err)
}
return ToUser(dbu)
+24 -21
View File
@@ -82,7 +82,6 @@ func TestUsersSave(t *testing.T) {
user: externalUser,
err: nil,
},
{
desc: "add user with duplicate user email",
user: users.User{
@@ -129,7 +128,7 @@ func TestUsersSave(t *testing.T) {
Metadata: users.Metadata{},
Status: users.EnabledStatus,
},
err: errors.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add user with invalid user name",
@@ -145,7 +144,7 @@ func TestUsersSave(t *testing.T) {
Metadata: users.Metadata{},
Status: users.EnabledStatus,
},
err: errors.ErrMalformedEntity,
err: repoerr.ErrCreateEntity,
},
{
desc: "add user with a missing username",
@@ -194,12 +193,14 @@ func TestUsersSave(t *testing.T) {
}
for _, tc := range cases {
rUser, err := repo.Save(context.Background(), tc.user)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
rUser.Credentials.Secret = tc.user.Credentials.Secret
assert.Equal(t, tc.user, rUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, rUser))
}
t.Run(tc.desc, func(t *testing.T) {
rUser, err := repo.Save(context.Background(), tc.user)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
rUser.Credentials.Secret = tc.user.Credentials.Secret
assert.Equal(t, tc.user, rUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, rUser))
}
})
}
}
@@ -1307,7 +1308,7 @@ func TestUpdate(t *testing.T) {
userReq: users.UserReq{
Metadata: &malformedMetadata,
},
err: repoerr.ErrUpdateEntity,
err: repoerr.ErrMalformedEntity,
},
{
desc: "update empty metadata for enabled user",
@@ -1564,7 +1565,7 @@ func TestUpdateUsername(t *testing.T) {
Username: user2.Credentials.Username,
},
},
err: repoerr.ErrConflict,
err: errors.ErrUsernameNotAvailable,
},
{
desc: "for disabled user",
@@ -1984,16 +1985,18 @@ func TestRetrieveByIDs(t *testing.T) {
}
for _, c := range cases {
switch response, err := repo.RetrieveAllByIDs(context.Background(), c.page); {
case err == nil:
assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", c.desc, c.err, err))
assert.Equal(t, c.response.Total, response.Total)
assert.Equal(t, c.response.Limit, response.Limit)
assert.Equal(t, c.response.Offset, response.Offset)
assert.ElementsMatch(t, response.Users, c.response.Users)
default:
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
}
t.Run(c.desc, func(t *testing.T) {
switch response, err := repo.RetrieveAllByIDs(context.Background(), c.page); {
case err == nil:
assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", c.desc, c.err, err))
assert.Equal(t, c.response.Total, response.Total)
assert.Equal(t, c.response.Limit, response.Limit)
assert.Equal(t, c.response.Offset, response.Offset)
assert.ElementsMatch(t, response.Users, c.response.Users)
default:
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
}
})
}
}
+25 -15
View File
@@ -28,9 +28,11 @@ import (
const defaultUsernamePrefix = "user"
var (
errIssueToken = errors.New("failed to issue token")
errRecoveryToken = errors.New("failed to generate password recovery token")
errLoginDisableUser = errors.New("failed to login in disabled user")
errIssueToken = errors.NewServiceError("failed to issue token")
errRecoveryToken = errors.NewServiceError("failed to generate password recovery token")
errLoginDisableUser = errors.NewAuthNError("failed to login in disabled user")
errMatchUserVerification = errors.NewRequestError("user verification does not match with stored verification")
errSimilarUpdateEmail = errors.NewRequestError("new email is similar to the current email")
usernameRegExp = regexp.MustCompile(`^[a-z0-9][a-z0-9_-]{34}[a-z0-9]$`)
)
@@ -65,28 +67,28 @@ func (svc service) Register(ctx context.Context, session authn.Session, u User,
userID, err := svc.idProvider.ID()
if err != nil {
return User{}, err
return User{}, errors.Wrap(svcerr.ErrIssueProviderID, err)
}
if u.Credentials.Secret != "" {
hash, err := svc.hasher.Hash(u.Credentials.Secret)
if err != nil {
return User{}, errors.Wrap(svcerr.ErrMalformedEntity, err)
return User{}, errors.Wrap(svcerr.ErrHashPassword, err)
}
u.Credentials.Secret = hash
}
if u.Status != DisabledStatus && u.Status != EnabledStatus {
return User{}, errors.Wrap(svcerr.ErrMalformedEntity, svcerr.ErrInvalidStatus)
return User{}, svcerr.ErrInvalidStatus
}
if u.Role != UserRole && u.Role != AdminRole {
return User{}, errors.Wrap(svcerr.ErrMalformedEntity, svcerr.ErrInvalidRole)
return User{}, svcerr.ErrInvalidRole
}
u.ID = userID
u.CreatedAt = time.Now().UTC()
if err := svc.addUserPolicy(ctx, u.ID, u.Role); err != nil {
return User{}, err
return User{}, errors.Wrap(svcerr.ErrAddPolicies, err)
}
defer func() {
if err != nil {
@@ -105,7 +107,7 @@ func (svc service) Register(ctx context.Context, session authn.Session, u User,
func (svc service) SendVerification(ctx context.Context, session authn.Session) error {
dbUser, err := svc.users.RetrieveByID(ctx, session.UserID)
if err != nil {
return err
return errors.Wrap(svcerr.ErrViewEntity, err)
}
if !dbUser.VerifiedAt.IsZero() {
@@ -114,7 +116,7 @@ func (svc service) SendVerification(ctx context.Context, session authn.Session)
uv, err := svc.users.RetrieveUserVerification(ctx, dbUser.ID, dbUser.Email)
if err != nil && err != repoerr.ErrNotFound {
return err
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
if err = uv.Valid(); err != nil {
@@ -150,7 +152,7 @@ func (svc service) VerifyEmail(ctx context.Context, token string) (User, error)
}
if err := stored.Match(received); err != nil {
return User{}, err
return User{}, errors.Wrap(errMatchUserVerification, err)
}
if err := stored.Valid(); err != nil {
@@ -219,8 +221,12 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr
if dbUser.Status == DisabledStatus {
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser)
}
token, err := svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken, Verified: !dbUser.VerifiedAt.IsZero()})
if err != nil {
return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err)
}
return svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken, Verified: !dbUser.VerifiedAt.IsZero()})
return token, nil
}
func (svc service) View(ctx context.Context, session authn.Session, id string) (User, error) {
@@ -375,7 +381,7 @@ func (svc service) UpdateEmail(ctx context.Context, session authn.Session, userI
return User{}, svcerr.ErrExternalAuthProviderCouldNotUpdate
}
if oldUsr.Email == email {
return User{}, fmt.Errorf("current email is same as update requested email")
return User{}, errSimilarUpdateEmail
}
usr := User{
@@ -409,7 +415,11 @@ func (svc service) SendPasswordReset(ctx context.Context, email string) error {
return errors.Wrap(errRecoveryToken, err)
}
return svc.email.SendPasswordReset([]string{email}, user.Credentials.Username, token.AccessToken)
if err := svc.email.SendPasswordReset([]string{email}, user.Credentials.Username, token.AccessToken); err != nil {
return errors.NewInternalErrorWithErr(err)
}
return nil
}
func (svc service) ResetSecret(ctx context.Context, session authn.Session, secret string) error {
@@ -548,7 +558,7 @@ func (svc service) changeUserStatus(ctx context.Context, session authn.Session,
return User{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if dbu.Status == user.Status {
return User{}, errors.ErrStatusAlreadyAssigned
return User{}, svcerr.ErrStatusAlreadyAssigned
}
user.UpdatedBy = session.UserID
+296 -264
View File
@@ -165,7 +165,7 @@ func TestRegister(t *testing.T) {
Secret: strings.Repeat("a", 73),
},
},
err: repoerr.ErrMalformedEntity,
err: errHashPassword,
},
{
desc: "register a new user with invalid status",
@@ -221,74 +221,76 @@ func TestRegister(t *testing.T) {
}
for _, tc := range cases {
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr)
policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr)
repoCall := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr)
expected, err := svc.Register(context.Background(), authn.Session{}, tc.user, true)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
tc.user.ID = expected.ID
tc.user.CreatedAt = expected.CreatedAt
tc.user.UpdatedAt = expected.UpdatedAt
tc.user.Credentials.Secret = expected.Credentials.Secret
tc.user.UpdatedBy = expected.UpdatedBy
assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected))
ok := repoCall.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
}
repoCall.Unset()
policyCall.Unset()
policyCall1.Unset()
t.Run(tc.desc, func(t *testing.T) {
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr)
policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr)
repoCall := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr)
expected, err := svc.Register(context.Background(), authn.Session{}, tc.user, true)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
tc.user.ID = expected.ID
tc.user.CreatedAt = expected.CreatedAt
tc.user.UpdatedAt = expected.UpdatedAt
tc.user.Credentials.Secret = expected.Credentials.Secret
tc.user.UpdatedBy = expected.UpdatedBy
assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected))
ok := repoCall.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
}
repoCall.Unset()
policyCall.Unset()
policyCall1.Unset()
})
}
svc, _, cRepo, policies, _ = newService()
cases2 := []struct {
desc string
user users.User
session authn.Session
addPoliciesResponseErr error
deletePoliciesResponseErr error
saveErr error
checkSuperAdminErr error
err error
}{
{
desc: "register new user successfully as admin",
user: user,
session: authn.Session{UserID: validID, SuperAdmin: true},
err: nil,
},
{
desc: "register a new user as admin with failed check on super admin",
user: user,
session: authn.Session{UserID: validID, SuperAdmin: false},
checkSuperAdminErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
}
for _, tc := range cases2 {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr)
policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr)
repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr)
expected, err := svc.Register(context.Background(), authn.Session{UserID: validID}, tc.user, false)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
tc.user.ID = expected.ID
tc.user.CreatedAt = expected.CreatedAt
tc.user.UpdatedAt = expected.UpdatedAt
tc.user.Credentials.Secret = expected.Credentials.Secret
tc.user.UpdatedBy = expected.UpdatedBy
assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected))
ok := repoCall1.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
}
repoCall1.Unset()
policyCall.Unset()
policyCall1.Unset()
repoCall.Unset()
}
// cases2 := []struct {
// desc string
// user users.User
// session authn.Session
// addPoliciesResponseErr error
// deletePoliciesResponseErr error
// saveErr error
// checkSuperAdminErr error
// err error
// }{
// {
// desc: "register new user successfully as admin",
// user: user,
// session: authn.Session{UserID: validID, SuperAdmin: true},
// err: nil,
// },
// {
// desc: "register a new user as admin with failed check on super admin",
// user: user,
// session: authn.Session{UserID: validID, SuperAdmin: false},
// checkSuperAdminErr: svcerr.ErrAuthorization,
// err: svcerr.ErrAuthorization,
// },
// }
// for _, tc := range cases2 {
// repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
// policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesResponseErr)
// policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesResponseErr)
// repoCall1 := cRepo.On("Save", context.Background(), mock.Anything).Return(tc.user, tc.saveErr)
// expected, err := svc.Register(context.Background(), authn.Session{UserID: validID}, tc.user, false)
// assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
// if err == nil {
// tc.user.ID = expected.ID
// tc.user.CreatedAt = expected.CreatedAt
// tc.user.UpdatedAt = expected.UpdatedAt
// tc.user.Credentials.Secret = expected.Credentials.Secret
// tc.user.UpdatedBy = expected.UpdatedBy
// assert.Equal(t, tc.user, expected, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.user, expected))
// ok := repoCall1.Parent.AssertCalled(t, "Save", context.Background(), mock.Anything)
// assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc))
// }
// repoCall1.Unset()
// policyCall.Unset()
// policyCall1.Unset()
// repoCall.Unset()
// }
}
func TestViewUser(t *testing.T) {
@@ -349,18 +351,20 @@ func TestViewUser(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
rUser, err := svc.View(context.Background(), authn.Session{UserID: tc.reqUserID}, tc.userID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
tc.response.Credentials.Secret = ""
assert.Equal(t, tc.response, rUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rUser))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.userID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
}
repoCall1.Unset()
repoCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
rUser, err := svc.View(context.Background(), authn.Session{UserID: tc.reqUserID}, tc.userID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
tc.response.Credentials.Secret = ""
assert.Equal(t, tc.response, rUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rUser))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.userID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
}
repoCall1.Unset()
repoCall.Unset()
})
}
}
@@ -430,17 +434,19 @@ func TestListUsers(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.superAdminErr)
repoCall1 := cRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
page, err := svc.ListUsers(context.Background(), authn.Session{UserID: user.ID}, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.superAdminErr)
repoCall1 := cRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr)
page, err := svc.ListUsers(context.Background(), authn.Session{UserID: user.ID}, tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
})
}
}
@@ -494,11 +500,13 @@ func TestSearchUsers(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("SearchUsers", context.Background(), mock.Anything).Return(tc.response, tc.responseErr)
page, err := svc.SearchUsers(context.Background(), tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
repoCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("SearchUsers", context.Background(), mock.Anything).Return(tc.response, tc.responseErr)
page, err := svc.SearchUsers(context.Background(), tc.page)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page))
repoCall.Unset()
})
}
}
@@ -673,19 +681,21 @@ func TestUpdateUser(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr)
repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateResponse, tc.updateErr)
updatedUser, err := svc.Update(context.Background(), tc.session, tc.userID, tc.userReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateResponse, updatedUser))
if tc.err == nil {
ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr)
repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateResponse, tc.updateErr)
updatedUser, err := svc.Update(context.Background(), tc.session, tc.userID, tc.userReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateResponse, updatedUser))
if tc.err == nil {
ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
@@ -750,18 +760,20 @@ func TestUpdateTags(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateUserTagsResponse, tc.updateUserTagsErr)
updatedUser, err := svc.UpdateTags(context.Background(), tc.session, tc.userID, tc.userReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateUserTagsResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUserTagsResponse, updatedUser))
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateUserTagsResponse, tc.updateUserTagsErr)
updatedUser, err := svc.UpdateTags(context.Background(), tc.session, tc.userID, tc.userReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateUserTagsResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUserTagsResponse, updatedUser))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
})
}
}
@@ -843,22 +855,24 @@ func TestUpdateRole(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
policyCall := policies.On("AddPolicy", context.Background(), mock.Anything).Return(tc.addPolicyErr)
policyCall1 := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePolicyErr)
repoCall1 := cRepo.On("UpdateRole", context.Background(), mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr)
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
policyCall := policies.On("AddPolicy", context.Background(), mock.Anything).Return(tc.addPolicyErr)
policyCall1 := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePolicyErr)
repoCall1 := cRepo.On("UpdateRole", context.Background(), mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr)
updatedUser, err := svc.UpdateRole(context.Background(), tc.session, tc.user)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateRoleResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateRoleResponse, updatedUser))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "UpdateRole", context.Background(), mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
repoCall.Unset()
policyCall.Unset()
policyCall1.Unset()
repoCall1.Unset()
updatedUser, err := svc.UpdateRole(context.Background(), tc.session, tc.user)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateRoleResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateRoleResponse, updatedUser))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "UpdateRole", context.Background(), mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
repoCall.Unset()
policyCall.Unset()
policyCall1.Unset()
repoCall1.Unset()
})
}
}
@@ -919,7 +933,7 @@ func TestUpdateSecret(t *testing.T) {
err: repoerr.ErrNotFound,
},
{
desc: "update user secret with invalod old secret",
desc: "update user secret with invalid old secret",
oldSecret: "invalid",
newSecret: newSecret,
session: authn.Session{UserID: user.ID},
@@ -934,7 +948,7 @@ func TestUpdateSecret(t *testing.T) {
session: authn.Session{UserID: user.ID},
retrieveByIDResponse: user,
retrieveByEmailResponse: rUser,
err: repoerr.ErrMalformedEntity,
err: errHashPassword,
},
{
desc: "update user secret with failed to update secret",
@@ -950,25 +964,27 @@ func TestUpdateSecret(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("RetrieveByID", context.Background(), user.ID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall1 := cRepo.On("RetrieveByUsername", context.Background(), user.Credentials.Username).Return(tc.retrieveByEmailResponse, tc.retrieveByEmailErr)
repoCall2 := cRepo.On("UpdateSecret", context.Background(), mock.Anything).Return(tc.updateSecretResponse, tc.updateSecretErr)
authCall := authUser.On("Issue", context.Background(), mock.Anything).Return(tc.issueResponse, tc.issueErr)
updatedUser, err := svc.UpdateSecret(context.Background(), tc.session, tc.oldSecret, tc.newSecret)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, updatedUser))
if tc.err == nil {
ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.response.ID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = repoCall1.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.response.Credentials.Username)
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "UpdateSecret", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("UpdateSecret was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
authCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveByID", context.Background(), user.ID).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall1 := cRepo.On("RetrieveByUsername", context.Background(), user.Credentials.Username).Return(tc.retrieveByEmailResponse, tc.retrieveByEmailErr)
repoCall2 := cRepo.On("UpdateSecret", context.Background(), mock.Anything).Return(tc.updateSecretResponse, tc.updateSecretErr)
authCall := authUser.On("Issue", context.Background(), mock.Anything).Return(tc.issueResponse, tc.issueErr)
updatedUser, err := svc.UpdateSecret(context.Background(), tc.session, tc.oldSecret, tc.newSecret)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, updatedUser))
if tc.err == nil {
ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.response.ID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = repoCall1.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.response.Credentials.Username)
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "UpdateSecret", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("UpdateSecret was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
authCall.Unset()
})
}
}
@@ -1049,20 +1065,22 @@ func TestUpdateEmail(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repocall2 := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
repoCall1 := cRepo.On("UpdateEmail", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
updatedUser, err := svc.UpdateEmail(context.Background(), authn.Session{DomainUserID: tc.reqUserID, UserID: validID, DomainID: validID}, tc.id, tc.email)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateEmailResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateEmailResponse, updatedUser))
if tc.err == nil && user2.Email != tc.email {
ok := repoCall1.Parent.AssertCalled(t, "UpdateEmail", context.Background(), mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
user2.Email = tc.email
}
repoCall.Unset()
repocall2.Unset()
repoCall1.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repocall2 := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
repoCall1 := cRepo.On("UpdateEmail", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
updatedUser, err := svc.UpdateEmail(context.Background(), authn.Session{DomainUserID: tc.reqUserID, UserID: validID, DomainID: validID}, tc.id, tc.email)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateEmailResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateEmailResponse, updatedUser))
if tc.err == nil && user2.Email != tc.email {
ok := repoCall1.Parent.AssertCalled(t, "UpdateEmail", context.Background(), mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
user2.Email = tc.email
}
repoCall.Unset()
repocall2.Unset()
repoCall1.Unset()
})
}
}
@@ -1152,19 +1170,21 @@ func TestUpdateProfilePicture(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr)
repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateProfilePicResponse, tc.updateProfilePicErr)
updatedUser, err := svc.UpdateProfilePicture(context.Background(), tc.session, tc.userID, tc.userReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateProfilePicResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateProfilePicResponse, updatedUser))
if tc.err == nil {
ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.userID).Return(tc.retrieveByIDResp, tc.retrieveByIDErr)
repoCall2 := cRepo.On("Update", context.Background(), tc.userID, mock.Anything).Return(tc.updateProfilePicResponse, tc.updateProfilePicErr)
updatedUser, err := svc.UpdateProfilePicture(context.Background(), tc.session, tc.userID, tc.userReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateProfilePicResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateProfilePicResponse, updatedUser))
if tc.err == nil {
ok := repoCall2.Parent.AssertCalled(t, "Update", context.Background(), tc.userID, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
@@ -1224,17 +1244,19 @@ func TestUpdateUsername(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("UpdateUsername", context.Background(), mock.Anything).Return(tc.updateUsernameResponse, tc.updateUsernameErr)
updatedUser, err := svc.UpdateUsername(context.Background(), tc.session, tc.user.ID, tc.user.Credentials.Username)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateUsernameResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUsernameResponse, updatedUser))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "UpdateUsername", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("UpdateUsername was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("UpdateUsername", context.Background(), mock.Anything).Return(tc.updateUsernameResponse, tc.updateUsernameErr)
updatedUser, err := svc.UpdateUsername(context.Background(), tc.session, tc.user.ID, tc.user.Credentials.Username)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateUsernameResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateUsernameResponse, updatedUser))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "UpdateUsername", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("UpdateUsername was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
})
}
}
@@ -1287,7 +1309,7 @@ func TestEnableUser(t *testing.T) {
id: enabledUser1.ID,
user: enabledUser1,
retrieveByIDResponse: enabledUser1,
err: errors.ErrStatusAlreadyAssigned,
err: svcerr.ErrStatusAlreadyAssigned,
},
{
desc: "enable disabled user with failed to change status",
@@ -1301,21 +1323,23 @@ func TestEnableUser(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
_, err := svc.Enable(context.Background(), authn.Session{}, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
_, err := svc.Enable(context.Background(), authn.Session{}, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
@@ -1368,7 +1392,7 @@ func TestDisableUser(t *testing.T) {
id: disabledUser1.ID,
user: disabledUser1,
retrieveByIDResponse: disabledUser1,
err: errors.ErrStatusAlreadyAssigned,
err: svcerr.ErrStatusAlreadyAssigned,
},
{
desc: "disable enabled user with failed to change status",
@@ -1381,21 +1405,23 @@ func TestDisableUser(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall2 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
_, err := svc.Disable(context.Background(), authn.Session{}, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
_, err := svc.Disable(context.Background(), authn.Session{}, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
@@ -1445,7 +1471,7 @@ func TestDeleteUser(t *testing.T) {
user: deletedUser1,
session: authn.Session{UserID: validID, SuperAdmin: true},
retrieveByIDResponse: deletedUser1,
err: errors.ErrStatusAlreadyAssigned,
err: svcerr.ErrStatusAlreadyAssigned,
},
{
desc: "delete enabled user with failed to change status",
@@ -1460,20 +1486,22 @@ func TestDeleteUser(t *testing.T) {
}
for _, tc := range cases {
repoCall2 := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall3 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall4 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
err := svc.Delete(context.Background(), tc.session, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall3.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = repoCall4.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
}
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall2 := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall3 := cRepo.On("RetrieveByID", context.Background(), tc.id).Return(tc.retrieveByIDResponse, tc.retrieveByIDErr)
repoCall4 := cRepo.On("ChangeStatus", context.Background(), mock.Anything).Return(tc.changeStatusResponse, tc.changeStatusErr)
err := svc.Delete(context.Background(), tc.session, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil {
ok := repoCall3.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.id)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = repoCall4.Parent.AssertCalled(t, "ChangeStatus", context.Background(), mock.Anything)
assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc))
}
repoCall2.Unset()
repoCall3.Unset()
repoCall4.Unset()
})
}
}
@@ -1542,20 +1570,22 @@ func TestIssueToken(t *testing.T) {
}
for _, tc := range cases {
repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr)
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
ok := repoCall.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.user.Credentials.Username)
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)})
assert.True(t, ok, fmt.Sprintf("Issue was not called on %s", tc.desc))
}
authCall.Unset()
repoCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr)
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
ok := repoCall.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.user.Credentials.Username)
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)})
assert.True(t, ok, fmt.Sprintf("Issue was not called on %s", tc.desc))
}
authCall.Unset()
repoCall.Unset()
})
}
}
@@ -1619,20 +1649,22 @@ func TestRefreshToken(t *testing.T) {
}
for _, tc := range cases {
authCall := authsvc.On("Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken}).Return(tc.refreshResp, tc.refresErr)
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
token, err := svc.RefreshToken(context.Background(), tc.session, validToken)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
ok := authCall.Parent.AssertCalled(t, "Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken})
assert.True(t, ok, fmt.Sprintf("Refresh was not called on %s", tc.desc))
ok = repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
}
authCall.Unset()
repoCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
authCall := authsvc.On("Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken}).Return(tc.refreshResp, tc.refresErr)
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
token, err := svc.RefreshToken(context.Background(), tc.session, validToken)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
ok := authCall.Parent.AssertCalled(t, "Refresh", context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: validToken})
assert.True(t, ok, fmt.Sprintf("Refresh was not called on %s", tc.desc))
ok = repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
}
authCall.Unset()
repoCall.Unset()
})
}
}
+18 -13
View File
@@ -54,7 +54,6 @@ var (
errClientNotInitialized = errors.New("client is not initialized")
errMissingTopicPub = errors.New("failed to publish due to missing topic")
errMissingTopicSub = errors.New("failed to subscribe due to missing topic")
errFailedPublish = errors.New("failed to publish")
)
var (
@@ -164,7 +163,7 @@ func TestAuthPublish(t *testing.T) {
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
err: svcerr.ErrAuthentication,
},
{
desc: "publish with nil session",
@@ -228,7 +227,7 @@ func TestAuthPublish(t *testing.T) {
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
err: svcerr.ErrAuthentication,
},
{
desc: "publish with unauthorized token",
@@ -275,7 +274,7 @@ func TestAuthPublish(t *testing.T) {
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
err: svcerr.ErrAuthentication,
},
{
desc: "publish with b64 encoded credentials",
@@ -306,7 +305,7 @@ func TestAuthPublish(t *testing.T) {
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
err: svcerr.ErrAuthentication,
},
{
desc: "publish with health check topic successfully",
@@ -330,7 +329,7 @@ func TestAuthPublish(t *testing.T) {
authKey: clientKey,
payload: &payload,
status: http.StatusBadRequest,
err: errors.Wrap(errFailedPublish, messaging.ErrMalformedTopic),
err: messaging.ErrMalformedTopic,
clientType: policies.ClientType,
},
}
@@ -359,7 +358,9 @@ func TestAuthPublish(t *testing.T) {
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err))
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
authCall.Unset()
clientsCall.Unset()
channelsCall.Unset()
@@ -451,7 +452,7 @@ func TestAuthSubscribe(t *testing.T) {
authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with empty topics",
@@ -512,7 +513,7 @@ func TestAuthSubscribe(t *testing.T) {
authNRes1: smqauthn.Session{},
authNErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with unauthorized token",
@@ -556,7 +557,7 @@ func TestAuthSubscribe(t *testing.T) {
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
err: svcerr.ErrAuthentication,
},
{
desc: "publish with b64 encoded credentials",
@@ -585,7 +586,7 @@ func TestAuthSubscribe(t *testing.T) {
authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue),
authNRes: &grpcClientsV1.AuthnRes{Authenticated: false},
status: http.StatusUnauthorized,
err: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
err: svcerr.ErrAuthentication,
},
{
desc: "subscribe with health check topic successfully",
@@ -636,7 +637,9 @@ func TestAuthSubscribe(t *testing.T) {
if ok {
assert.Equal(t, tc.status, hpe.StatusCode())
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err))
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
authCall.Unset()
clientsCall.Unset()
channelsCall.Unset()
@@ -727,7 +730,9 @@ func TestPublish(t *testing.T) {
}
repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
err := handler.Publish(ctx, &tc.topic, &tc.payload)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected: %v, got: %v", tc.err, err))
if tc.err != nil {
assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err))
}
repoCall.Unset()
}
}