MG-12 - Improve Error Handling (#18)

* Add service error type

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Wrap errors in users service

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Solve merge errors

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Wrap errors in users service

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Wrap errors in things service

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Wrap errors in twins

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Wrap errors in bootstrap

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update provision

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update error tags

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Remove repo errors from transport layer

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Add copyright headers

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Resolve conflicts

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Remove apiutil from service

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update postgres errors

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Handle token errors

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Handle token errors

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update auth errors

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update auth errors

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update errors in auth

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update users service

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update license header

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* fix ci

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update error definitions

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update error wrapping

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* update error type definitions

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* update error type definitions

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* update error type definitions

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Rename import aliases

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Rename import aliases

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Refactor postgres errors

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update errors

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update errors

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Fix ci

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update errors in transport

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* Update errors in transport

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

* update error definitions

Signed-off-by: felix.gateru <felix.gateru@gmail.com>

---------

Signed-off-by: felix.gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2023-11-24 02:52:04 +03:00
committed by GitHub
parent dd3fb241b0
commit 705560efd6
33 changed files with 486 additions and 328 deletions
+14 -4
View File
@@ -21,6 +21,16 @@ var (
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
// ErrExpiry indicates that the token is expired.
ErrExpiry = errors.New("token is expired")
// ErrSetClaim indicates an inability to set the claim.
ErrSetClaim = errors.New("failed to set claim")
// ErrSignJWT indicates an error in signing jwt token.
ErrSignJWT = errors.New("failed to sign jwt token")
// ErrParseToken indicates a failure to parse the token.
ErrParseToken = errors.New("failed to parse token")
// ErrValidateJWTToken indicates a failure to validate JWT token.
ErrValidateJWTToken = errors.New("failed to validate jwt token")
// ErrJSONHandle indicates an error in handling JSON.
ErrJSONHandle = errors.New("failed to perform operation JSON")
)
const (
@@ -62,7 +72,7 @@ func (repo *tokenizer) Issue(key auth.Key) (string, error) {
}
signedTkn, err := jwt.Sign(tkn, jwt.WithKey(jwa.HS512, repo.secret))
if err != nil {
return "", err
return "", errors.Wrap(ErrSignJWT, err)
}
return string(signedTkn), nil
}
@@ -87,16 +97,16 @@ func (repo *tokenizer) Parse(token string) (auth.Key, error) {
return nil
})
if err := jwt.Validate(tkn, jwt.WithValidator(validator)); err != nil {
return auth.Key{}, err
return auth.Key{}, errors.Wrap(ErrValidateJWTToken, err)
}
jsn, err := json.Marshal(tkn.PrivateClaims())
if err != nil {
return auth.Key{}, err
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
}
var key auth.Key
if err := json.Unmarshal(jsn, &key); err != nil {
return auth.Key{}, err
return auth.Key{}, errors.Wrap(ErrJSONHandle, err)
}
tType, ok := tkn.Get(tokenType)
+5 -4
View File
@@ -8,6 +8,7 @@ import (
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
)
@@ -23,7 +24,7 @@ type Service struct {
func (m *Service) Issue(ctx context.Context, in *magistrala.IssueReq, opts ...grpc.CallOption) (*magistrala.Token, error) {
ret := m.Called(ctx, in)
if in.GetUserId() == InvalidValue || in.GetUserId() == "" {
return &magistrala.Token{}, errors.ErrAuthentication
return &magistrala.Token{}, svcerr.ErrAuthentication
}
return ret.Get(0).(*magistrala.Token), ret.Error(1)
@@ -32,7 +33,7 @@ func (m *Service) Issue(ctx context.Context, in *magistrala.IssueReq, opts ...gr
func (m *Service) Refresh(ctx context.Context, in *magistrala.RefreshReq, opts ...grpc.CallOption) (*magistrala.Token, error) {
ret := m.Called(ctx, in)
if in.GetRefreshToken() == InvalidValue || in.GetRefreshToken() == "" {
return &magistrala.Token{}, errors.ErrAuthentication
return &magistrala.Token{}, svcerr.ErrAuthentication
}
return ret.Get(0).(*magistrala.Token), ret.Error(1)
@@ -41,7 +42,7 @@ func (m *Service) Refresh(ctx context.Context, in *magistrala.RefreshReq, opts .
func (m *Service) Identify(ctx context.Context, in *magistrala.IdentityReq, opts ...grpc.CallOption) (*magistrala.IdentityRes, error) {
ret := m.Called(ctx, in)
if in.GetToken() == InvalidValue || in.GetToken() == "" {
return &magistrala.IdentityRes{}, errors.ErrAuthentication
return &magistrala.IdentityRes{}, svcerr.ErrAuthentication
}
return ret.Get(0).(*magistrala.IdentityRes), ret.Error(1)
@@ -50,7 +51,7 @@ func (m *Service) Identify(ctx context.Context, in *magistrala.IdentityReq, opts
func (m *Service) Authorize(ctx context.Context, in *magistrala.AuthorizeReq, opts ...grpc.CallOption) (*magistrala.AuthorizeRes, error) {
ret := m.Called(ctx, in)
if in.GetSubject() == InvalidValue || in.GetSubject() == "" {
return &magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization
return &magistrala.AuthorizeRes{Authorized: false}, svcerr.ErrAuthorization
}
if in.GetObject() == InvalidValue || in.GetObject() == "" {
return &magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization
+14 -13
View File
@@ -16,6 +16,7 @@ import (
"github.com/absmach/magistrala/internal/postgres"
"github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/jackc/pgtype"
"github.com/jmoiron/sqlx"
)
@@ -41,24 +42,24 @@ func (repo domainRepo) Save(ctx context.Context, d auth.Domain) (ad auth.Domain,
dbd, err := toDBDomains(d)
if err != nil {
return auth.Domain{}, errors.Wrap(errors.ErrCreateEntity, err)
return auth.Domain{}, errors.Wrap(repoerr.ErrCreateEntity, repoerr.ErrRollbackTx)
}
row, err := repo.db.NamedQueryContext(ctx, q, dbd)
if err != nil {
return auth.Domain{}, postgres.HandleError(err, errors.ErrCreateEntity)
return auth.Domain{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
}
defer row.Close()
row.Next()
dbd = dbDomain{}
if err := row.StructScan(&dbd); err != nil {
return auth.Domain{}, err
return auth.Domain{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
domain, err := toDomain(dbd)
if err != nil {
return auth.Domain{}, err
return auth.Domain{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
return domain, nil
@@ -99,7 +100,7 @@ func (repo domainRepo) RetrieveAllByIDs(ctx context.Context, pm auth.Page) (auth
}
query, err := buildPageQuery(pm)
if err != nil {
return auth.DomainsPage{}, err
return auth.DomainsPage{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
if query == "" {
return auth.DomainsPage{}, nil
@@ -147,7 +148,7 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm auth.Page) (auth.Doma
var q string
query, err := buildPageQuery(pm)
if err != nil {
return auth.DomainsPage{}, err
return auth.DomainsPage{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
if query == "" {
return auth.DomainsPage{}, nil
@@ -236,19 +237,19 @@ func (repo domainRepo) Update(ctx context.Context, id string, userID string, dr
}
row, err := repo.db.NamedQueryContext(ctx, q, dbd)
if err != nil {
return auth.Domain{}, postgres.HandleError(err, errors.ErrUpdateEntity)
return auth.Domain{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
// defer row.Close()
row.Next()
dbd = dbDomain{}
if err := row.StructScan(&dbd); err != nil {
return auth.Domain{}, err
return auth.Domain{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
domain, err := toDomain(dbd)
if err != nil {
return auth.Domain{}, err
return auth.Domain{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
return domain, nil
@@ -265,7 +266,7 @@ func (repo domainRepo) Delete(ctx context.Context, id string) error {
row, err := repo.db.NamedQueryContext(ctx, q, nil)
if err != nil {
return postgres.HandleError(err, errors.ErrRemoveEntity)
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
}
defer row.Close()
@@ -281,7 +282,7 @@ func (repo domainRepo) SavePolicies(ctx context.Context, pcs ...auth.Policy) err
dbpc := toDBPolicies(pcs...)
row, err := repo.db.NamedQueryContext(ctx, q, dbpc)
if err != nil {
return postgres.HandleError(err, errors.ErrCreateEntity)
return postgres.HandleError(repoerr.ErrCreateEntity, err)
}
defer row.Close()
@@ -305,7 +306,7 @@ func (repo domainRepo) CheckPolicy(ctx context.Context, pc auth.Policy) error {
dbpc := toDBPolicy(pc)
row, err := repo.db.NamedQueryContext(ctx, q, dbpc)
if err != nil {
return postgres.HandleError(err, errors.ErrCreateEntity)
return postgres.HandleError(repoerr.ErrCreateEntity, err)
}
defer row.Close()
row.Next()
@@ -345,7 +346,7 @@ func (repo domainRepo) DeletePolicies(ctx context.Context, pcs ...auth.Policy) (
dbpc := toDBPolicy(pc)
row, err := tx.NamedQuery(q, dbpc)
if err != nil {
return postgres.HandleError(err, errors.ErrRemoveEntity)
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
}
defer row.Close()
}
+2 -2
View File
@@ -37,7 +37,7 @@ func (kr *repo) Save(ctx context.Context, key auth.Key) (string, error) {
dbKey := toDBKey(key)
if _, err := kr.db.NamedExecContext(ctx, q, dbKey); err != nil {
return "", postgres.HandleError(err, errSave)
return "", postgres.HandleError(errSave, err)
}
return dbKey.ID, nil
@@ -51,7 +51,7 @@ func (kr *repo) Retrieve(ctx context.Context, issuerID, id string) (auth.Key, er
return auth.Key{}, errors.ErrNotFound
}
return auth.Key{}, postgres.HandleError(err, errRetrieve)
return auth.Key{}, postgres.HandleError(errRetrieve, err)
}
return toKey(key), nil
+47 -43
View File
@@ -10,9 +10,9 @@ import (
"time"
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/internal/apiutil"
"github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
)
const (
@@ -42,12 +42,16 @@ var (
// ErrFailedToRetrieveChildren failed to retrieve groups.
ErrFailedToRetrieveChildren = errors.New("failed to retrieve all groups")
errIssueUser = errors.New("failed to issue new login key")
errIssueTmp = errors.New("failed to issue new temporary key")
errRevoke = errors.New("failed to remove key")
errRetrieve = errors.New("failed to retrieve key data")
errIdentify = errors.New("failed to validate token")
errPlatform = errors.New("invalid platform id")
errIssueUser = errors.New("failed to issue new login key")
errIssueTmp = errors.New("failed to issue new temporary key")
errRevoke = errors.New("failed to remove key")
errRetrieve = errors.New("failed to retrieve key data")
errIdentify = errors.New("failed to validate token")
errPlatform = errors.New("invalid platform id")
errCreateDomainPolicy = errors.New("failed to create domain policy")
errAddPolicies = errors.New("failed to add policies")
errRemovePolicies = errors.New("failed to remove the policies")
errInvalidPolicy = errors.New("failed to validate policy")
)
// Authn specifies an API that must be fullfiled by the domain service
@@ -159,40 +163,40 @@ func (svc service) Identify(ctx context.Context, token string) (Key, error) {
case APIKey:
_, err := svc.keys.Retrieve(ctx, key.Issuer, key.ID)
if err != nil {
return Key{}, errors.ErrAuthentication
return Key{}, svcerr.ErrAuthentication
}
return key, nil
default:
return Key{}, errors.ErrAuthentication
return Key{}, svcerr.ErrAuthentication
}
}
func (svc service) Authorize(ctx context.Context, pr PolicyReq) error {
if err := svc.PolicyValidation(pr); err != nil {
return err
return errors.Wrap(errInvalidPolicy, err)
}
if pr.SubjectKind == TokenKind {
key, err := svc.Identify(ctx, pr.Subject)
if err != nil {
return err
return errors.Wrap(svcerr.ErrAuthentication, err)
}
if key.Subject == "" {
if pr.ObjectType == GroupType || pr.ObjectType == ThingType || pr.ObjectType == DomainType {
return errors.ErrDomainAuthorization
}
return errors.ErrAuthentication
return svcerr.ErrAuthentication
}
pr.Subject = key.Subject
}
if err := svc.agent.CheckPolicy(ctx, pr); err != nil {
return errors.Wrap(errors.ErrAuthorization, err)
return errors.Wrap(svcerr.ErrAuthorization, err)
}
return nil
}
func (svc service) AddPolicy(ctx context.Context, pr PolicyReq) error {
if err := svc.PolicyValidation(pr); err != nil {
return err
return errors.Wrap(errInvalidPolicy, err)
}
return svc.agent.AddPolicy(ctx, pr)
}
@@ -207,7 +211,7 @@ func (svc service) PolicyValidation(pr PolicyReq) error {
func (svc service) AddPolicies(ctx context.Context, prs []PolicyReq) error {
for _, pr := range prs {
if err := svc.PolicyValidation(pr); err != nil {
return err
return errors.Wrap(errInvalidPolicy, err)
}
}
return svc.agent.AddPolicies(ctx, prs)
@@ -220,7 +224,7 @@ func (svc service) DeletePolicy(ctx context.Context, pr PolicyReq) error {
func (svc service) DeletePolicies(ctx context.Context, prs []PolicyReq) error {
for _, pr := range prs {
if err := svc.PolicyValidation(pr); err != nil {
return err
return errors.Wrap(errInvalidPolicy, err)
}
}
return svc.agent.DeletePolicies(ctx, prs)
@@ -232,7 +236,7 @@ func (svc service) ListObjects(ctx context.Context, pr PolicyReq, nextPageToken
}
res, npt, err := svc.agent.RetrieveObjects(ctx, pr, nextPageToken, limit)
if err != nil {
return PolicyPage{}, err
return PolicyPage{}, errors.Wrap(svcerr.ErrNotFound, err)
}
var page PolicyPage
for _, tuple := range res {
@@ -245,13 +249,13 @@ func (svc service) ListObjects(ctx context.Context, pr PolicyReq, nextPageToken
func (svc service) ListAllObjects(ctx context.Context, pr PolicyReq) (PolicyPage, error) {
res, err := svc.agent.RetrieveAllObjects(ctx, pr)
if err != nil {
return PolicyPage{}, err
return PolicyPage{}, errors.Wrap(svcerr.ErrNotFound, err)
}
var page PolicyPage
for _, tuple := range res {
page.Policies = append(page.Policies, tuple.Object)
}
return page, err
return page, errors.Wrap(errors.ErrNotFound, err)
}
func (svc service) CountObjects(ctx context.Context, pr PolicyReq) (int, error) {
@@ -264,26 +268,26 @@ func (svc service) ListSubjects(ctx context.Context, pr PolicyReq, nextPageToken
}
res, npt, err := svc.agent.RetrieveSubjects(ctx, pr, nextPageToken, limit)
if err != nil {
return PolicyPage{}, err
return PolicyPage{}, errors.Wrap(svcerr.ErrNotFound, err)
}
var page PolicyPage
for _, tuple := range res {
page.Policies = append(page.Policies, tuple.Subject)
}
page.NextPageToken = npt
return page, err
return page, errors.Wrap(svcerr.ErrNotFound, err)
}
func (svc service) ListAllSubjects(ctx context.Context, pr PolicyReq) (PolicyPage, error) {
res, err := svc.agent.RetrieveAllSubjects(ctx, pr)
if err != nil {
return PolicyPage{}, err
return PolicyPage{}, errors.Wrap(svcerr.ErrNotFound, err)
}
var page PolicyPage
for _, tuple := range res {
page.Policies = append(page.Policies, tuple.Subject)
}
return page, err
return page, errors.Wrap(svcerr.ErrNotFound, err)
}
func (svc service) CountSubjects(ctx context.Context, pr PolicyReq) (int, error) {
@@ -307,7 +311,7 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) {
key.Subject, err = svc.checkUserDomain(ctx, key)
if err != nil {
return Token{}, err
return Token{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
access, err := svc.tokenizer.Issue(key)
@@ -327,7 +331,7 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) {
func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token, error) {
k, err := svc.tokenizer.Parse(token)
if err != nil {
return Token{}, err
return Token{}, errors.Wrap(errRetrieve, err)
}
if k.Type != RefreshKey {
return Token{}, errIssueUser
@@ -341,7 +345,7 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
key.Subject, err = svc.checkUserDomain(ctx, key)
if err != nil {
return Token{}, err
return Token{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
key.ExpiresAt = time.Now().Add(svc.loginDuration)
@@ -419,11 +423,11 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e
func (svc service) authenticate(token string) (string, string, error) {
key, err := svc.tokenizer.Parse(token)
if err != nil {
return "", "", err
return "", "", errors.Wrap(svcerr.ErrAuthentication, err)
}
// Only login key token is valid for login.
if key.Type != AccessKey || key.Issuer == "" {
return "", "", errors.ErrAuthentication
return "", "", svcerr.ErrAuthentication
}
return key.Issuer, key.Subject, nil
@@ -448,24 +452,24 @@ func SwitchToPermission(relation string) string {
func (svc service) CreateDomain(ctx context.Context, token string, d Domain) (do Domain, err error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return Domain{}, err
return Domain{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
d.CreatedBy = key.User
domainID, err := svc.idProvider.ID()
if err != nil {
return Domain{}, err
return Domain{}, errors.Wrap(svcerr.ErrUniqueID, err)
}
d.ID = domainID
if d.Status != clients.DisabledStatus && d.Status != clients.EnabledStatus {
return Domain{}, apiutil.ErrInvalidStatus
return Domain{}, svcerr.ErrInvalidStatus
}
d.CreatedAt = time.Now()
if err := svc.createDomainPolicy(ctx, key.User, domainID, AdministratorRelation); err != nil {
return Domain{}, err
return Domain{}, errors.Wrap(errCreateDomainPolicy, err)
}
defer func() {
if err != nil {
@@ -487,7 +491,7 @@ func (svc service) RetrieveDomain(ctx context.Context, token string, id string)
ObjectType: DomainType,
Permission: ViewPermission,
}); err != nil {
return Domain{}, err
return Domain{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
return svc.domains.RetrieveByID(ctx, id)
@@ -496,7 +500,7 @@ func (svc service) RetrieveDomain(ctx context.Context, token string, id string)
func (svc service) UpdateDomain(ctx context.Context, token string, id string, d DomainReq) (Domain, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return Domain{}, err
return Domain{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.Authorize(ctx, PolicyReq{
Subject: key.Subject,
@@ -506,7 +510,7 @@ func (svc service) UpdateDomain(ctx context.Context, token string, id string, d
ObjectType: DomainType,
Permission: EditPermission,
}); err != nil {
return Domain{}, err
return Domain{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
return svc.domains.Update(ctx, id, key.User, d)
}
@@ -514,7 +518,7 @@ func (svc service) UpdateDomain(ctx context.Context, token string, id string, d
func (svc service) ChangeDomainStatus(ctx context.Context, token string, id string, d DomainReq) (Domain, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return Domain{}, err
return Domain{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.Authorize(ctx, PolicyReq{
Subject: key.Subject,
@@ -524,7 +528,7 @@ func (svc service) ChangeDomainStatus(ctx context.Context, token string, id stri
ObjectType: DomainType,
Permission: AdminPermission,
}); err != nil {
return Domain{}, err
return Domain{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
return svc.domains.Update(ctx, id, key.User, d)
}
@@ -532,7 +536,7 @@ func (svc service) ChangeDomainStatus(ctx context.Context, token string, id stri
func (svc service) ListDomains(ctx context.Context, token string, p Page) (DomainsPage, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return DomainsPage{}, err
return DomainsPage{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
p.SubjectID = key.User
if err := svc.Authorize(ctx, PolicyReq{
@@ -578,7 +582,7 @@ func (svc service) AssignUsers(ctx context.Context, token string, id string, use
Object: MagistralaObject,
ObjectType: PlatformType,
}); err != nil {
return errors.Wrap(errors.ErrMalformedEntity, fmt.Errorf("invalid user id : %s ", userID))
return errors.Wrap(svcerr.ErrMalformedEntity, fmt.Errorf("invalid user id : %s ", userID))
}
}
@@ -615,7 +619,7 @@ func (svc service) UnassignUsers(ctx context.Context, token string, id string, u
func (svc service) ListUserDomains(ctx context.Context, token string, userID string, p Page) (DomainsPage, error) {
res, err := svc.Identify(ctx, token)
if err != nil {
return DomainsPage{}, err
return DomainsPage{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.Authorize(ctx, PolicyReq{
Subject: res.User,
@@ -624,7 +628,7 @@ func (svc service) ListUserDomains(ctx context.Context, token string, userID str
Object: MagistralaObject,
ObjectType: PlatformType,
}); err != nil {
return DomainsPage{}, err
return DomainsPage{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
if userID != "" && res.User != userID {
p.SubjectID = userID
@@ -656,7 +660,7 @@ func (svc service) addDomainPolicies(ctx context.Context, domainID, relation str
})
}
if err := svc.agent.AddPolicies(ctx, prs); err != nil {
return err
return errors.Wrap(errAddPolicies, err)
}
defer func() {
if err != nil {
@@ -762,7 +766,7 @@ func (svc service) removeDomainPolicies(ctx context.Context, domainID, relation
})
}
if err := svc.agent.DeletePolicies(ctx, prs); err != nil {
return err
return errors.Wrap(errRemovePolicies, err)
}
return svc.domains.DeletePolicies(ctx, pcs...)
+13 -11
View File
@@ -11,6 +11,8 @@ import (
"github.com/absmach/magistrala/auth"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/authzed-go/v1"
)
@@ -49,13 +51,13 @@ func (pa *policyAgent) CheckPolicy(ctx context.Context, pr auth.PolicyReq) error
resp, err := pa.permissionClient.CheckPermission(ctx, &checkReq)
if err != nil {
return errors.Wrap(errors.ErrMalformedEntity, errors.Wrap(errPermission, err))
return errors.Wrap(svcerr.ErrMalformedEntity, errors.Wrap(errPermission, err))
}
if resp.Permissionship == v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION {
return nil
}
if reason, ok := v1.CheckPermissionResponse_Permissionship_name[int32(resp.Permissionship)]; ok {
return errors.Wrap(errors.ErrAuthorization, errors.New(reason))
return errors.Wrap(svcerr.ErrAuthorization, errors.New(reason))
}
return errors.ErrAuthorization
}
@@ -83,7 +85,7 @@ func (pa *policyAgent) AddPolicies(ctx context.Context, prs []auth.PolicyReq) er
}
_, err := pa.permissionClient.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{Updates: updates, OptionalPreconditions: preconds})
if err != nil {
return errors.Wrap(errors.ErrMalformedEntity, errors.Wrap(errAddPolicies, err))
return errors.Wrap(svcerr.ErrMalformedEntity, errors.Wrap(errAddPolicies, err))
}
return nil
}
@@ -106,7 +108,7 @@ func (pa *policyAgent) AddPolicy(ctx context.Context, pr auth.PolicyReq) error {
}
_, err = pa.permissionClient.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{Updates: updates, OptionalPreconditions: precond})
if err != nil {
return errors.Wrap(errors.ErrMalformedEntity, errors.Wrap(errAddPolicies, err))
return errors.Wrap(svcerr.ErrMalformedEntity, errors.Wrap(errAddPolicies, err))
}
return nil
}
@@ -128,7 +130,7 @@ func (pa *policyAgent) DeletePolicies(ctx context.Context, prs []auth.PolicyReq)
}
_, err := pa.permissionClient.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{Updates: updates})
if err != nil {
return errors.Wrap(errors.ErrMalformedEntity, errors.Wrap(errRemovePolicies, err))
return errors.Wrap(svcerr.ErrMalformedEntity, errors.Wrap(errRemovePolicies, err))
}
return nil
}
@@ -149,7 +151,7 @@ func (pa *policyAgent) DeletePolicy(ctx context.Context, pr auth.PolicyReq) erro
},
}
if _, err := pa.permissionClient.DeleteRelationships(ctx, req); err != nil {
return errors.Wrap(errors.ErrMalformedEntity, errors.Wrap(errRemovePolicies, err))
return errors.Wrap(svcerr.ErrMalformedEntity, errors.Wrap(errRemovePolicies, err))
}
return nil
}
@@ -167,7 +169,7 @@ func (pa *policyAgent) RetrieveObjects(ctx context.Context, pr auth.PolicyReq, n
}
stream, err := pa.permissionClient.LookupResources(ctx, resourceReq)
if err != nil {
return nil, "", errors.Wrap(errors.ErrMalformedEntity, errors.Wrap(errRetrievePolicies, err))
return nil, "", errors.Wrap(repoerr.ErrMalformedEntity, errors.Wrap(errRetrievePolicies, err))
}
resources := []*v1.LookupResourcesResponse{}
var token string
@@ -185,7 +187,7 @@ func (pa *policyAgent) RetrieveObjects(ctx context.Context, pr auth.PolicyReq, n
if len(resources) > 0 && resources[len(resources)-1].AfterResultCursor != nil {
token = resources[len(resources)-1].AfterResultCursor.Token
}
return objectsToAuthPolicies(resources), token, errors.Wrap(errors.ErrViewEntity, err)
return objectsToAuthPolicies(resources), token, errors.Wrap(repoerr.ErrViewEntity, err)
}
}
}
@@ -198,7 +200,7 @@ func (pa *policyAgent) RetrieveAllObjects(ctx context.Context, pr auth.PolicyReq
}
stream, err := pa.permissionClient.LookupResources(ctx, resourceReq)
if err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, errors.Wrap(errRetrievePolicies, err))
return nil, errors.Wrap(svcerr.ErrMalformedEntity, errors.Wrap(errRetrievePolicies, err))
}
tuples := []auth.PolicyRes{}
for {
@@ -245,7 +247,7 @@ func (pa *policyAgent) RetrieveSubjects(ctx context.Context, pr auth.PolicyReq,
}
stream, err := pa.permissionClient.LookupSubjects(ctx, &subjectsReq)
if err != nil {
return nil, "", errors.Wrap(errors.ErrMalformedEntity, errors.Wrap(errRetrievePolicies, err))
return nil, "", errors.Wrap(svcerr.ErrMalformedEntity, errors.Wrap(errRetrievePolicies, err))
}
subjects := []*v1.LookupSubjectsResponse{}
var token string
@@ -264,7 +266,7 @@ func (pa *policyAgent) RetrieveSubjects(ctx context.Context, pr auth.PolicyReq,
if len(subjects) > 0 && subjects[len(subjects)-1].AfterResultCursor != nil {
token = subjects[len(subjects)-1].AfterResultCursor.Token
}
return subjectsToAuthPolicies(subjects), token, errors.Wrap(errors.ErrViewEntity, err)
return subjectsToAuthPolicies(subjects), token, errors.Wrap(repoerr.ErrViewEntity, err)
}
}
}
+10 -9
View File
@@ -12,6 +12,7 @@ import (
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
mgsdk "github.com/absmach/magistrala/pkg/sdk/go"
)
@@ -122,7 +123,7 @@ func New(auth magistrala.AuthServiceClient, configs ConfigRepository, sdk mgsdk.
func (bs bootstrapService) Add(ctx context.Context, token string, cfg Config) (Config, error) {
owner, err := bs.identify(ctx, token)
if err != nil {
return Config{}, err
return Config{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
toConnect := bs.toIDList(cfg.Channels)
@@ -169,7 +170,7 @@ func (bs bootstrapService) Add(ctx context.Context, token string, cfg Config) (C
func (bs bootstrapService) View(ctx context.Context, token, id string) (Config, error) {
owner, err := bs.identify(ctx, token)
if err != nil {
return Config{}, err
return Config{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
return bs.configs.RetrieveByID(ctx, owner, id)
@@ -178,7 +179,7 @@ func (bs bootstrapService) View(ctx context.Context, token, id string) (Config,
func (bs bootstrapService) Update(ctx context.Context, token string, cfg Config) error {
owner, err := bs.identify(ctx, token)
if err != nil {
return err
return errors.Wrap(svcerr.ErrAuthentication, err)
}
cfg.Owner = owner
@@ -189,7 +190,7 @@ func (bs bootstrapService) Update(ctx context.Context, token string, cfg Config)
func (bs bootstrapService) UpdateCert(ctx context.Context, token, thingID, clientCert, clientKey, caCert string) (Config, error) {
owner, err := bs.identify(ctx, token)
if err != nil {
return Config{}, err
return Config{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
cfg, err := bs.configs.UpdateCert(ctx, owner, thingID, clientCert, clientKey, caCert)
if err != nil {
@@ -201,7 +202,7 @@ func (bs bootstrapService) UpdateCert(ctx context.Context, token, thingID, clien
func (bs bootstrapService) UpdateConnections(ctx context.Context, token, id string, connections []string) error {
owner, err := bs.identify(ctx, token)
if err != nil {
return err
return errors.Wrap(svcerr.ErrAuthentication, err)
}
cfg, err := bs.configs.RetrieveByID(ctx, owner, id)
@@ -255,7 +256,7 @@ func (bs bootstrapService) UpdateConnections(ctx context.Context, token, id stri
func (bs bootstrapService) List(ctx context.Context, token string, filter Filter, offset, limit uint64) (ConfigsPage, error) {
owner, err := bs.identify(ctx, token)
if err != nil {
return ConfigsPage{}, err
return ConfigsPage{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
return bs.configs.RetrieveAll(ctx, owner, filter, offset, limit), nil
@@ -264,7 +265,7 @@ func (bs bootstrapService) List(ctx context.Context, token string, filter Filter
func (bs bootstrapService) Remove(ctx context.Context, token, id string) error {
owner, err := bs.identify(ctx, token)
if err != nil {
return err
return errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := bs.configs.Remove(ctx, owner, id); err != nil {
return errors.Wrap(errRemoveBootstrap, err)
@@ -294,7 +295,7 @@ func (bs bootstrapService) Bootstrap(ctx context.Context, externalKey, externalI
func (bs bootstrapService) ChangeState(ctx context.Context, token, id string, state State) error {
owner, err := bs.identify(ctx, token)
if err != nil {
return err
return errors.Wrap(svcerr.ErrAuthentication, err)
}
cfg, err := bs.configs.RetrieveByID(ctx, owner, id)
@@ -367,7 +368,7 @@ func (bs bootstrapService) identify(ctx context.Context, token string) (string,
res, err := bs.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return "", errors.ErrAuthentication
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
return res.GetId(), nil
+13 -9
View File
@@ -10,6 +10,8 @@ import (
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/certs/pki"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
mgsdk "github.com/absmach/magistrala/pkg/sdk/go"
)
@@ -21,6 +23,8 @@ var (
ErrFailedCertRevocation = errors.New("failed to revoke certificate")
ErrFailedToRemoveCertFromDB = errors.New("failed to remove cert serial from db")
ErrFailedReadFromPKI = errors.New("failed to read certificate from PKI")
)
var _ Service = (*certsService)(nil)
@@ -82,7 +86,7 @@ type Cert struct {
func (cs *certsService) IssueCert(ctx context.Context, token, thingID string, ttl string) (Cert, error) {
owner, err := cs.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return Cert{}, err
return Cert{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
thing, err := cs.sdk.Thing(thingID, token)
@@ -115,7 +119,7 @@ func (cs *certsService) RevokeCert(ctx context.Context, token, thingID string) (
var revoke Revoke
u, err := cs.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return revoke, err
return revoke, errors.Wrap(svcerr.ErrAuthentication, err)
}
thing, err := cs.sdk.Thing(thingID, token)
if err != nil {
@@ -145,18 +149,18 @@ func (cs *certsService) RevokeCert(ctx context.Context, token, thingID string) (
func (cs *certsService) ListCerts(ctx context.Context, token, thingID string, offset, limit uint64) (Page, error) {
u, err := cs.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return Page{}, err
return Page{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
cp, err := cs.certsRepo.RetrieveByThing(ctx, u.GetId(), thingID, offset, limit)
if err != nil {
return Page{}, err
return Page{}, errors.Wrap(repoerr.ErrNotFound, err)
}
for i, cert := range cp.Certs {
vcert, err := cs.pki.Read(cert.Serial)
if err != nil {
return Page{}, err
return Page{}, errors.Wrap(ErrFailedReadFromPKI, err)
}
cp.Certs[i].ClientCert = vcert.ClientCert
cp.Certs[i].ClientKey = vcert.ClientKey
@@ -168,7 +172,7 @@ func (cs *certsService) ListCerts(ctx context.Context, token, thingID string, of
func (cs *certsService) ListSerials(ctx context.Context, token, thingID string, offset, limit uint64) (Page, error) {
u, err := cs.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return Page{}, err
return Page{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
return cs.certsRepo.RetrieveByThing(ctx, u.GetId(), thingID, offset, limit)
@@ -177,17 +181,17 @@ func (cs *certsService) ListSerials(ctx context.Context, token, thingID string,
func (cs *certsService) ViewCert(ctx context.Context, token, serialID string) (Cert, error) {
u, err := cs.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return Cert{}, err
return Cert{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
cert, err := cs.certsRepo.RetrieveBySerial(ctx, u.GetId(), serialID)
if err != nil {
return Cert{}, err
return Cert{}, errors.Wrap(repoerr.ErrNotFound, err)
}
vcert, err := cs.pki.Read(serialID)
if err != nil {
return Cert{}, err
return Cert{}, errors.Wrap(ErrFailedReadFromPKI, err)
}
c := Cert{
+9 -8
View File
@@ -13,6 +13,7 @@ import (
"github.com/absmach/magistrala/internal/postgres"
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/gofrs/uuid"
)
@@ -104,7 +105,7 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
w.Header().Set("Content-Type", ContentType)
switch {
case errors.Contains(err, apiutil.ErrInvalidSecret),
errors.Contains(err, errors.ErrMalformedEntity),
errors.Contains(err, svcerr.ErrMalformedEntity),
errors.Contains(err, apiutil.ErrMissingID),
errors.Contains(err, apiutil.ErrEmptyList),
errors.Contains(err, apiutil.ErrMissingMemberType),
@@ -113,20 +114,20 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
w.WriteHeader(http.StatusBadRequest)
case errors.Contains(err, errors.ErrAuthentication):
w.WriteHeader(http.StatusUnauthorized)
case errors.Contains(err, errors.ErrNotFound):
case errors.Contains(err, svcerr.ErrNotFound):
w.WriteHeader(http.StatusNotFound)
case errors.Contains(err, errors.ErrConflict):
case errors.Contains(err, svcerr.ErrConflict):
w.WriteHeader(http.StatusConflict)
case errors.Contains(err, errors.ErrAuthorization):
case errors.Contains(err, svcerr.ErrAuthorization):
w.WriteHeader(http.StatusForbidden)
case errors.Contains(err, postgres.ErrMemberAlreadyAssigned):
w.WriteHeader(http.StatusConflict)
case errors.Contains(err, apiutil.ErrUnsupportedContentType):
w.WriteHeader(http.StatusUnsupportedMediaType)
case errors.Contains(err, errors.ErrCreateEntity),
errors.Contains(err, errors.ErrUpdateEntity),
errors.Contains(err, errors.ErrViewEntity),
errors.Contains(err, errors.ErrRemoveEntity):
case errors.Contains(err, svcerr.ErrCreateEntity),
errors.Contains(err, svcerr.ErrUpdateEntity),
errors.Contains(err, svcerr.ErrViewEntity),
errors.Contains(err, svcerr.ErrRemoveEntity):
w.WriteHeader(http.StatusInternalServerError)
default:
w.WriteHeader(http.StatusInternalServerError)
+7 -7
View File
@@ -7,7 +7,7 @@ import (
"context"
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerror "github.com/absmach/magistrala/pkg/errors/repository"
mggroups "github.com/absmach/magistrala/pkg/groups"
"github.com/stretchr/testify/mock"
)
@@ -24,11 +24,11 @@ func (m *Repository) ChangeStatus(ctx context.Context, group mggroups.Group) (mg
ret := m.Called(ctx, group)
if group.ID == WrongID {
return mggroups.Group{}, errors.ErrNotFound
return mggroups.Group{}, repoerror.ErrNotFound
}
if group.Status != mgclients.EnabledStatus && group.Status != mgclients.DisabledStatus {
return mggroups.Group{}, errors.ErrMalformedEntity
return mggroups.Group{}, repoerror.ErrMalformedEntity
}
return ret.Get(0).(mggroups.Group), ret.Error(1)
@@ -56,7 +56,7 @@ func (m *Repository) RetrieveByID(ctx context.Context, id string) (mggroups.Grou
ret := m.Called(ctx, id)
if id == WrongID {
return mggroups.Group{}, errors.ErrNotFound
return mggroups.Group{}, repoerror.ErrNotFound
}
return ret.Get(0).(mggroups.Group), ret.Error(1)
@@ -66,11 +66,11 @@ func (m *Repository) Save(ctx context.Context, g mggroups.Group) (mggroups.Group
ret := m.Called(ctx, g)
if g.Parent == WrongID {
return mggroups.Group{}, errors.ErrCreateEntity
return mggroups.Group{}, repoerror.ErrCreateEntity
}
if g.Owner == WrongID {
return mggroups.Group{}, errors.ErrCreateEntity
return mggroups.Group{}, repoerror.ErrCreateEntity
}
return g, ret.Error(1)
@@ -80,7 +80,7 @@ func (m *Repository) Update(ctx context.Context, g mggroups.Group) (mggroups.Gro
ret := m.Called(ctx, g)
if g.ID == WrongID {
return mggroups.Group{}, errors.ErrNotFound
return mggroups.Group{}, repoerror.ErrNotFound
}
return ret.Get(0).(mggroups.Group), ret.Error(1)
+17 -16
View File
@@ -14,6 +14,7 @@ import (
"github.com/absmach/magistrala/internal/postgres"
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerror "github.com/absmach/magistrala/pkg/errors/repository"
mggroups "github.com/absmach/magistrala/pkg/groups"
"github.com/jmoiron/sqlx"
)
@@ -42,7 +43,7 @@ func (repo groupRepository) Save(ctx context.Context, g mggroups.Group) (mggroup
}
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
if err != nil {
return mggroups.Group{}, postgres.HandleError(err, errors.ErrCreateEntity)
return mggroups.Group{}, postgres.HandleError(repoerror.ErrCreateEntity, err)
}
defer row.Close()
@@ -77,21 +78,21 @@ func (repo groupRepository) Update(ctx context.Context, g mggroups.Group) (mggro
dbu, err := toDBGroup(g)
if err != nil {
return mggroups.Group{}, errors.Wrap(errors.ErrUpdateEntity, err)
return mggroups.Group{}, errors.Wrap(repoerror.ErrUpdateEntity, err)
}
row, err := repo.db.NamedQueryContext(ctx, q, dbu)
if err != nil {
return mggroups.Group{}, postgres.HandleError(err, errors.ErrUpdateEntity)
return mggroups.Group{}, postgres.HandleError(repoerror.ErrUpdateEntity, err)
}
defer row.Close()
if ok := row.Next(); !ok {
return mggroups.Group{}, errors.Wrap(errors.ErrNotFound, row.Err())
return mggroups.Group{}, errors.Wrap(repoerror.ErrNotFound, row.Err())
}
dbu = dbGroup{}
if err := row.StructScan(&dbu); err != nil {
return mggroups.Group{}, errors.Wrap(err, errors.ErrUpdateEntity)
return mggroups.Group{}, errors.Wrap(err, repoerror.ErrUpdateEntity)
}
return toGroup(dbu)
}
@@ -102,20 +103,20 @@ func (repo groupRepository) ChangeStatus(ctx context.Context, group mggroups.Gro
dbg, err := toDBGroup(group)
if err != nil {
return mggroups.Group{}, errors.Wrap(errors.ErrUpdateEntity, err)
return mggroups.Group{}, errors.Wrap(repoerror.ErrUpdateEntity, err)
}
row, err := repo.db.NamedQueryContext(ctx, qc, dbg)
if err != nil {
return mggroups.Group{}, postgres.HandleError(err, errors.ErrUpdateEntity)
return mggroups.Group{}, postgres.HandleError(repoerror.ErrUpdateEntity, err)
}
defer row.Close()
if ok := row.Next(); !ok {
return mggroups.Group{}, errors.Wrap(errors.ErrNotFound, row.Err())
return mggroups.Group{}, errors.Wrap(repoerror.ErrNotFound, row.Err())
}
dbg = dbGroup{}
if err := row.StructScan(&dbg); err != nil {
return mggroups.Group{}, errors.Wrap(err, errors.ErrUpdateEntity)
return mggroups.Group{}, errors.Wrap(err, repoerror.ErrUpdateEntity)
}
return toGroup(dbg)
@@ -132,16 +133,16 @@ func (repo groupRepository) RetrieveByID(ctx context.Context, id string) (mggrou
row, err := repo.db.NamedQueryContext(ctx, q, dbg)
if err != nil {
if err == sql.ErrNoRows {
return mggroups.Group{}, errors.Wrap(errors.ErrNotFound, err)
return mggroups.Group{}, errors.Wrap(repoerror.ErrNotFound, err)
}
return mggroups.Group{}, errors.Wrap(errors.ErrViewEntity, err)
return mggroups.Group{}, errors.Wrap(repoerror.ErrViewEntity, err)
}
defer row.Close()
row.Next()
dbg = dbGroup{}
if err := row.StructScan(&dbg); err != nil {
return mggroups.Group{}, errors.Wrap(errors.ErrNotFound, err)
return mggroups.Group{}, errors.Wrap(repoerror.ErrNotFound, err)
}
return toGroup(dbg)
@@ -198,7 +199,7 @@ func (repo groupRepository) RetrieveAll(ctx context.Context, gm mggroups.Page) (
func (repo groupRepository) RetrieveByIDs(ctx context.Context, gm mggroups.Page, ids ...string) (mggroups.Page, error) {
var q string
if len(ids) <= 0 {
return mggroups.Page{}, errors.ErrNotFound
return mggroups.Page{}, repoerror.ErrNotFound
}
query, err := buildQuery(gm, ids...)
if err != nil {
@@ -266,7 +267,7 @@ func (repo groupRepository) AssignParentGroup(ctx context.Context, parentGroupID
row, err := repo.db.QueryContext(ctx, query)
if err != nil {
return postgres.HandleError(err, errors.ErrUpdateEntity)
return postgres.HandleError(repoerror.ErrUpdateEntity, err)
}
defer row.Close()
@@ -293,7 +294,7 @@ func (repo groupRepository) UnassignParentGroup(ctx context.Context, parentGroup
row, err := repo.db.QueryContext(ctx, query)
if err != nil {
return postgres.HandleError(err, errors.ErrUpdateEntity)
return postgres.HandleError(repoerror.ErrUpdateEntity, err)
}
defer row.Close()
@@ -400,7 +401,7 @@ func toGroup(g dbGroup) (mggroups.Group, error) {
var metadata mgclients.Metadata
if g.Metadata != nil {
if err := json.Unmarshal(g.Metadata, &metadata); err != nil {
return mggroups.Group{}, errors.Wrap(errors.ErrMalformedEntity, err)
return mggroups.Group{}, errors.Wrap(repoerror.ErrMalformedEntity, err)
}
}
var parentID string
+5 -4
View File
@@ -5,6 +5,7 @@ package postgres
import (
"github.com/absmach/magistrala/pkg/errors"
repoerror "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/jackc/pgx/v5/pgconn"
)
@@ -17,16 +18,16 @@ const (
errInvalid = "22P02" // invalid_text_representation
)
func HandleError(err, wrapper error) error {
func HandleError(wrapper, err error) error {
pqErr, ok := err.(*pgconn.PgError)
if ok {
switch pqErr.Code {
case errDuplicate:
return errors.Wrap(errors.ErrConflict, err)
return errors.Wrap(repoerror.ErrConflict, err)
case errInvalid, errTruncation:
return errors.Wrap(errors.ErrMalformedEntity, err)
return errors.Wrap(repoerror.ErrMalformedEntity, err)
case errFK:
return errors.Wrap(errors.ErrCreateEntity, err)
return errors.Wrap(repoerror.ErrCreateEntity, err)
}
}
+21 -20
View File
@@ -14,6 +14,7 @@ import (
"github.com/absmach/magistrala/internal/postgres"
"github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/absmach/magistrala/pkg/groups"
"github.com/jackc/pgtype"
)
@@ -94,16 +95,16 @@ func (repo ClientRepository) RetrieveByID(ctx context.Context, id string) (clien
row, err := repo.DB.NamedQueryContext(ctx, q, dbc)
if err != nil {
if err == sql.ErrNoRows {
return clients.Client{}, errors.Wrap(errors.ErrNotFound, err)
return clients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
return clients.Client{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer row.Close()
row.Next()
dbc = DBClient{}
if err := row.StructScan(&dbc); err != nil {
return clients.Client{}, errors.Wrap(errors.ErrNotFound, err)
return clients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
return ToClient(dbc)
@@ -121,16 +122,16 @@ func (repo ClientRepository) RetrieveByIdentity(ctx context.Context, identity st
row, err := repo.DB.NamedQueryContext(ctx, q, dbc)
if err != nil {
if err == sql.ErrNoRows {
return clients.Client{}, errors.Wrap(errors.ErrNotFound, err)
return clients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
return clients.Client{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer row.Close()
row.Next()
dbc = DBClient{}
if err := row.StructScan(&dbc); err != nil {
return clients.Client{}, errors.Wrap(errors.ErrNotFound, err)
return clients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
return ToClient(dbc)
@@ -139,7 +140,7 @@ func (repo ClientRepository) RetrieveByIdentity(ctx context.Context, identity st
func (repo ClientRepository) RetrieveAll(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
query, err := pageQuery(pm)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.owner_id, '') AS owner_id, c.status,
@@ -159,7 +160,7 @@ func (repo ClientRepository) RetrieveAll(ctx context.Context, pm clients.Page) (
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
c, err := ToClient(dbc)
@@ -173,7 +174,7 @@ func (repo ClientRepository) RetrieveAll(ctx context.Context, pm clients.Page) (
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
page := clients.ClientsPage{
@@ -191,7 +192,7 @@ func (repo ClientRepository) RetrieveAll(ctx context.Context, pm clients.Page) (
func (repo ClientRepository) RetrieveAllBasicInfo(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) {
query, err := pageQuery(pm)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query)
@@ -210,7 +211,7 @@ func (repo ClientRepository) RetrieveAllBasicInfo(ctx context.Context, pm client
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
c, err := ToClient(dbc)
@@ -224,7 +225,7 @@ func (repo ClientRepository) RetrieveAllBasicInfo(ctx context.Context, pm client
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
page := clients.ClientsPage{
@@ -247,7 +248,7 @@ func (repo ClientRepository) RetrieveAllByIDs(ctx context.Context, pm clients.Pa
}
query, err := pageQuery(pm)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.owner_id, '') AS owner_id, c.status,
@@ -267,7 +268,7 @@ func (repo ClientRepository) RetrieveAllByIDs(ctx context.Context, pm clients.Pa
for rows.Next() {
dbc := DBClient{}
if err := rows.StructScan(&dbc); err != nil {
return clients.ClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
c, err := ToClient(dbc)
@@ -281,7 +282,7 @@ func (repo ClientRepository) RetrieveAllByIDs(ctx context.Context, pm clients.Pa
total, err := postgres.Total(ctx, repo.DB, cq, dbPage)
if err != nil {
return clients.ClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
page := clients.ClientsPage{
@@ -300,12 +301,12 @@ func (repo ClientRepository) RetrieveAllByIDs(ctx context.Context, pm clients.Pa
func (repo ClientRepository) update(ctx context.Context, client clients.Client, query string) (clients.Client, error) {
dbc, err := ToDBClient(client)
if err != nil {
return clients.Client{}, errors.Wrap(errors.ErrUpdateEntity, err)
return clients.Client{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
row, err := repo.DB.NamedQueryContext(ctx, query, dbc)
if err != nil {
return clients.Client{}, postgres.HandleError(err, errors.ErrUpdateEntity)
return clients.Client{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
@@ -341,7 +342,7 @@ func ToDBClient(c clients.Client) (DBClient, error) {
if len(c.Metadata) > 0 {
b, err := json.Marshal(c.Metadata)
if err != nil {
return DBClient{}, errors.Wrap(errors.ErrMalformedEntity, err)
return DBClient{}, errors.Wrap(repoerr.ErrMalformedEntity, err)
}
data = b
}
@@ -422,7 +423,7 @@ func ToClient(c DBClient) (clients.Client, error) {
func toDBClientsPage(pm clients.Page) (dbClientsPage, error) {
_, data, err := postgres.CreateMetadataQuery("", pm.Metadata)
if err != nil {
return dbClientsPage{}, errors.Wrap(errors.ErrViewEntity, err)
return dbClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
var role clients.Role
if pm.Role != nil {
@@ -459,7 +460,7 @@ type dbClientsPage struct {
func pageQuery(pm clients.Page) (string, error) {
mq, _, err := postgres.CreateMetadataQuery("", pm.Metadata)
if err != nil {
return "", errors.Wrap(errors.ErrViewEntity, err)
return "", errors.Wrap(repoerr.ErrViewEntity, err)
}
var query []string
var emq string
+48
View File
@@ -0,0 +1,48 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package repository
import "github.com/absmach/magistrala/pkg/errors"
// Wrapper for Repository errors.
var (
// ErrMalformedEntity indicates a malformed entity specification.
ErrMalformedEntity = errors.New("malformed entity specification")
// ErrNotFound indicates a non-existent entity request.
ErrNotFound = errors.New("entity not found")
// ErrConflict indicates that entity already exists.
ErrConflict = errors.New("entity already exists")
// ErrCreateEntity indicates error in creating entity or entities.
ErrCreateEntity = errors.New("failed to create entity in the db")
// ErrViewEntity indicates error in viewing entity or entities.
ErrViewEntity = errors.New("view entity failed")
// ErrUpdateEntity indicates error in updating entity or entities.
ErrUpdateEntity = errors.New("update entity failed")
// ErrRemoveEntity indicates error in removing entity.
ErrRemoveEntity = errors.New("failed to remove entity")
// ErrScanMetadata indicates problem with metadata in db.
ErrScanMetadata = errors.New("failed to scan metadata in db")
// ErrWrongSecret indicates a wrong secret was provided.
ErrWrongSecret = errors.New("wrong secret")
// ErrLogin indicates wrong login credentials.
ErrLogin = errors.New("invalid user id or secret")
// ErrFailedOpDB indicates a failure in a database operation.
ErrFailedOpDB = errors.New("operation on db element failed")
// ErrRollbackTx indicates failed to rollback transaction.
ErrRollbackTx = errors.New("failed to rollback transaction")
// ErrMissingSecret indicates missing secret.
ErrMissingSecret = errors.New("missing secret")
)
+45
View File
@@ -0,0 +1,45 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package service
import "github.com/absmach/magistrala/pkg/errors"
// Wrapper for Service errors.
var (
// ErrAuthentication indicates failure occurred while authenticating the entity.
ErrAuthentication = errors.New("authentication error")
// ErrAuthorization indicates failure occurred while authorizing the entity.
ErrAuthorization = errors.New("failed to perform authorization over the entity")
// ErrMalformedEntity indicates a malformed entity specification.
ErrMalformedEntity = errors.New("malformed entity specification")
// ErrNotFound indicates a non-existent entity request.
ErrNotFound = errors.New("entity not found")
// ErrConflict indicates that entity already exists.
ErrConflict = errors.New("entity already exists")
// ErrCreateEntity indicates error in creating entity or entities.
ErrCreateEntity = errors.New("failed to create entity in the db")
// ErrRemoveEntity indicates error in removing entity.
ErrRemoveEntity = errors.New("failed to remove entity")
// ErrViewEntity indicates error in viewing entity or entities.
ErrViewEntity = errors.New("view entity failed")
// ErrUpdateEntity indicates error in updating entity or entities.
ErrUpdateEntity = errors.New("update entity failed")
// ErrUniqueID indicates an error in generating a unique ID.
ErrUniqueID = errors.New("failed to generate unique identifier")
// ErrInvalidStatus indicates an invalid status.
ErrInvalidStatus = errors.New("Invalid status")
// ErrInvalidRole indicates that an invalid role.
ErrInvalidRole = errors.New("invalid client role")
)
+4 -1
View File
@@ -3,6 +3,8 @@
package errors
import "errors"
var (
// ErrAuthentication indicates failure occurred while authenticating the entity.
ErrAuthentication = New("failed to perform authentication over the entity")
@@ -43,5 +45,6 @@ var (
// ErrLogin indicates wrong login credentials.
ErrLogin = New("invalid user id or secret")
ErrUnsupportedContentType = New("invalid content type")
// ErrUnsupportedContentType indicates invalid content type.
ErrUnsupportedContentType = errors.New("invalid content type")
)
+3 -2
View File
@@ -18,6 +18,7 @@ import (
mglog "github.com/absmach/magistrala/logger"
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerror "github.com/absmach/magistrala/pkg/errors/repository"
mggroups "github.com/absmach/magistrala/pkg/groups"
sdk "github.com/absmach/magistrala/pkg/sdk/go"
"github.com/absmach/magistrala/things"
@@ -703,7 +704,7 @@ func TestEnableChannel(t *testing.T) {
repoCall1 := gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil)
repoCall2 := gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(nil)
_, err := mgsdk.EnableChannel("wrongID", adminToken)
assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.Wrap(mggroups.ErrEnableGroup, errors.ErrNotFound), http.StatusNotFound), fmt.Sprintf("Enable channel with wrong id: expected %v got %v", errors.ErrNotFound, err))
assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.Wrap(mggroups.ErrEnableGroup, repoerror.ErrNotFound), http.StatusNotFound), fmt.Sprintf("Enable channel with wrong id: expected %v got %v", repoerror.ErrNotFound, err))
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, "wrongID")
assert.True(t, ok, "RetrieveByID was not called on enabling channel")
repoCall1.Unset()
@@ -753,7 +754,7 @@ func TestDisableChannel(t *testing.T) {
repoCall1 := gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(sdk.ErrFailedRemoval)
repoCall2 := gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil)
_, err := mgsdk.DisableChannel("wrongID", adminToken)
assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.Wrap(mggroups.ErrDisableGroup, errors.ErrNotFound), http.StatusNotFound), fmt.Sprintf("Disable channel with wrong id: expected %v got %v", errors.ErrNotFound, err))
assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.Wrap(mggroups.ErrDisableGroup, repoerror.ErrNotFound), http.StatusNotFound), fmt.Sprintf("Disable channel with wrong id: expected %v got %v", repoerror.ErrNotFound, err))
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, "wrongID")
assert.True(t, ok, "Memberships was not called on disabling channel with wrong id")
repoCall1.Unset()
+3 -2
View File
@@ -18,6 +18,7 @@ import (
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerror "github.com/absmach/magistrala/pkg/errors/repository"
mggroups "github.com/absmach/magistrala/pkg/groups"
sdk "github.com/absmach/magistrala/pkg/sdk/go"
"github.com/absmach/magistrala/users"
@@ -758,7 +759,7 @@ func TestEnableGroup(t *testing.T) {
repoCall1 := gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil)
repoCall2 := gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(sdk.ErrFailedRemoval)
_, err := mgsdk.EnableGroup("wrongID", validToken)
assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), fmt.Sprintf("Enable group with wrong id: expected %v got %v", errors.ErrNotFound, err))
assert.Equal(t, err, errors.NewSDKErrorWithStatus(repoerror.ErrNotFound, http.StatusNotFound), fmt.Sprintf("Enable group with wrong id: expected %v got %v", repoerror.ErrNotFound, err))
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, "wrongID")
assert.True(t, ok, "RetrieveByID was not called on enabling group")
repoCall1.Unset()
@@ -808,7 +809,7 @@ func TestDisableGroup(t *testing.T) {
repoCall1 := gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(sdk.ErrFailedRemoval)
repoCall2 := gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil)
_, err := mgsdk.DisableGroup("wrongID", validToken)
assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), fmt.Sprintf("Disable group with wrong id: expected %v got %v", errors.ErrNotFound, err))
assert.Equal(t, err, errors.NewSDKErrorWithStatus(repoerror.ErrNotFound, http.StatusNotFound), fmt.Sprintf("Disable group with wrong id: expected %v got %v", repoerror.ErrNotFound, err))
ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, "wrongID")
assert.True(t, ok, "Memberships was not called on disabling group with wrong id")
repoCall1.Unset()
+3 -3
View File
@@ -117,7 +117,7 @@ func (ps *provisionService) Provision(token, name, externalID, externalKey strin
token, err = ps.createTokenIfEmpty(token)
if err != nil {
return res, err
return res, errors.Wrap(ErrFailedToCreateToken, err)
}
if len(ps.conf.Things) == 0 {
@@ -162,7 +162,7 @@ func (ps *provisionService) Provision(token, name, externalID, externalKey strin
}
ch, err := ps.sdk.CreateChannel(ch, token)
if err != nil {
return res, err
return res, errors.Wrap(ErrFailedChannelCreation, err)
}
ch, err = ps.sdk.Channel(ch.ID, token)
if err != nil {
@@ -257,7 +257,7 @@ func (ps *provisionService) Provision(token, name, externalID, externalKey strin
func (ps *provisionService) Cert(token, thingID, ttl string) (string, string, error) {
token, err := ps.createTokenIfEmpty(token)
if err != nil {
return "", "", err
return "", "", errors.Wrap(ErrFailedToCreateToken, err)
}
th, err := ps.sdk.Thing(thingID, token)
+3 -3
View File
@@ -7,7 +7,7 @@ import (
"context"
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
)
@@ -21,10 +21,10 @@ type Service struct {
func (m *Service) Authorize(ctx context.Context, in *magistrala.AuthorizeReq, opts ...grpc.CallOption) (*magistrala.AuthorizeRes, error) {
ret := m.Called(ctx, in)
if in.GetSubject() == WrongID || in.GetSubject() == "" {
return &magistrala.AuthorizeRes{}, errors.ErrAuthorization
return &magistrala.AuthorizeRes{}, svcerr.ErrAuthorization
}
if in.GetObject() == WrongID || in.GetObject() == "" {
return &magistrala.AuthorizeRes{}, errors.ErrAuthorization
return &magistrala.AuthorizeRes{}, svcerr.ErrAuthorization
}
return ret.Get(0).(*magistrala.AuthorizeRes), ret.Error(1)
+14 -14
View File
@@ -7,7 +7,7 @@ import (
"context"
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/stretchr/testify/mock"
)
@@ -28,11 +28,11 @@ func (m *Repository) ChangeStatus(ctx context.Context, client mgclients.Client)
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
if client.Status != mgclients.EnabledStatus && client.Status != mgclients.DisabledStatus {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -41,7 +41,7 @@ func (m *Repository) ChangeStatus(ctx context.Context, client mgclients.Client)
func (m *Repository) Members(ctx context.Context, groupID string, pm mgclients.Page) (mgclients.MembersPage, error) {
ret := m.Called(ctx, groupID, pm)
if groupID == WrongID {
return mgclients.MembersPage{}, errors.ErrNotFound
return mgclients.MembersPage{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.MembersPage), ret.Error(1)
@@ -63,7 +63,7 @@ func (m *Repository) RetrieveByID(ctx context.Context, id string) (mgclients.Cli
ret := m.Called(ctx, id)
if id == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -73,7 +73,7 @@ func (m *Repository) RetrieveBySecret(ctx context.Context, secret string) (mgcli
ret := m.Called(ctx, secret)
if secret == "" {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -83,7 +83,7 @@ func (m *Repository) Save(ctx context.Context, clis ...mgclients.Client) ([]mgcl
ret := m.Called(ctx, clis)
for _, cli := range clis {
if cli.Owner == WrongID {
return []mgclients.Client{}, errors.ErrMalformedEntity
return []mgclients.Client{}, repoerr.ErrMalformedEntity
}
}
return clis, ret.Error(1)
@@ -93,7 +93,7 @@ func (m *Repository) Update(ctx context.Context, client mgclients.Client) (mgcli
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
}
@@ -102,10 +102,10 @@ func (m *Repository) UpdateIdentity(ctx context.Context, client mgclients.Client
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
if client.Credentials.Identity == "" {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -115,10 +115,10 @@ func (m *Repository) UpdateSecret(ctx context.Context, client mgclients.Client)
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
if client.Credentials.Secret == "" {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -128,7 +128,7 @@ func (m *Repository) UpdateTags(ctx context.Context, client mgclients.Client) (m
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -138,7 +138,7 @@ func (m *Repository) UpdateOwner(ctx context.Context, client mgclients.Client) (
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
+2 -2
View File
@@ -7,7 +7,7 @@ import (
"context"
"sync"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/absmach/magistrala/things"
)
@@ -37,7 +37,7 @@ func (tcm *clientCacheMock) ID(_ context.Context, key string) (string, error) {
id, ok := tcm.things[key]
if !ok {
return "", errors.ErrNotFound
return "", repoerr.ErrNotFound
}
return id, nil
+10 -9
View File
@@ -12,6 +12,7 @@ import (
mgclients "github.com/absmach/magistrala/pkg/clients"
pgclients "github.com/absmach/magistrala/pkg/clients/postgres"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
)
var _ mgclients.Repository = (*clientRepo)(nil)
@@ -42,7 +43,7 @@ func NewRepository(db postgres.Database) Repository {
func (repo clientRepo) Save(ctx context.Context, cs ...mgclients.Client) ([]mgclients.Client, error) {
tx, err := repo.ClientRepository.DB.BeginTxx(ctx, nil)
if err != nil {
return []mgclients.Client{}, errors.Wrap(errors.ErrCreateEntity, err)
return []mgclients.Client{}, errors.Wrap(repoerr.ErrCreateEntity, err)
}
var clients []mgclients.Client
@@ -53,32 +54,32 @@ func (repo clientRepo) Save(ctx context.Context, cs ...mgclients.Client) ([]mgcl
dbcli, err := pgclients.ToDBClient(cli)
if err != nil {
return []mgclients.Client{}, errors.Wrap(errors.ErrCreateEntity, err)
return []mgclients.Client{}, errors.Wrap(repoerr.ErrCreateEntity, err)
}
row, err := repo.ClientRepository.DB.NamedQueryContext(ctx, q, dbcli)
if err != nil {
if err := tx.Rollback(); err != nil {
return []mgclients.Client{}, postgres.HandleError(err, errors.ErrCreateEntity)
return []mgclients.Client{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
}
return []mgclients.Client{}, errors.Wrap(errors.ErrCreateEntity, err)
return []mgclients.Client{}, errors.Wrap(repoerr.ErrCreateEntity, err)
}
defer row.Close()
row.Next()
dbcli = pgclients.DBClient{}
if err := row.StructScan(&dbcli); err != nil {
return []mgclients.Client{}, err
return []mgclients.Client{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
client, err := pgclients.ToClient(dbcli)
if err != nil {
return []mgclients.Client{}, err
return []mgclients.Client{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
clients = append(clients, client)
}
if err = tx.Commit(); err != nil {
return []mgclients.Client{}, errors.Wrap(errors.ErrCreateEntity, err)
return []mgclients.Client{}, errors.Wrap(repoerr.ErrCreateEntity, err)
}
return clients, nil
@@ -95,9 +96,9 @@ func (repo clientRepo) RetrieveBySecret(ctx context.Context, key string) (mgclie
if err := repo.DB.QueryRowxContext(ctx, q, key).StructScan(&dbc); err != nil {
if err == sql.ErrNoRows {
return mgclients.Client{}, errors.Wrap(errors.ErrNotFound, err)
return mgclients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
return mgclients.Client{}, errors.Wrap(errors.ErrViewEntity, err)
return mgclients.Client{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
return pgclients.ToClient(dbc)
+42 -36
View File
@@ -8,13 +8,19 @@ import (
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/auth"
"github.com/absmach/magistrala/internal/apiutil"
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
mggroups "github.com/absmach/magistrala/pkg/groups"
"github.com/absmach/magistrala/things/postgres"
)
var (
errAddPolicies = errors.New("failed to add policies")
errRemovePolicies = errors.New("failed to remove the policies")
)
type service struct {
auth magistrala.AuthServiceClient
clients postgres.Repository
@@ -37,7 +43,7 @@ func NewService(uauth magistrala.AuthServiceClient, c postgres.Repository, grepo
func (svc service) Authorize(ctx context.Context, req *magistrala.AuthorizeReq) (string, error) {
thingID, err := svc.Identify(ctx, req.GetSubject())
if err != nil {
return "", errors.ErrAuthentication
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
r := &magistrala.AuthorizeReq{
@@ -49,10 +55,10 @@ func (svc service) Authorize(ctx context.Context, req *magistrala.AuthorizeReq)
}
resp, err := svc.auth.Authorize(ctx, r)
if err != nil {
return "", err
return "", errors.Wrap(errors.ErrAuthorization, err)
}
if !resp.GetAuthorized() {
return "", errors.ErrAuthorization
return "", errors.Wrap(errors.ErrAuthorization, err)
}
return thingID, nil
@@ -61,26 +67,26 @@ func (svc service) Authorize(ctx context.Context, req *magistrala.AuthorizeReq)
func (svc service) CreateThings(ctx context.Context, token string, cls ...mgclients.Client) ([]mgclients.Client, error) {
user, err := svc.identify(ctx, token)
if err != nil {
return []mgclients.Client{}, errors.Wrap(errors.ErrAuthorization, err)
return []mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
var clients []mgclients.Client
for _, c := range cls {
if c.ID == "" {
clientID, err := svc.idProvider.ID()
if err != nil {
return []mgclients.Client{}, err
return []mgclients.Client{}, errors.Wrap(svcerr.ErrUniqueID, err)
}
c.ID = clientID
}
if c.Credentials.Secret == "" {
key, err := svc.idProvider.ID()
if err != nil {
return []mgclients.Client{}, err
return []mgclients.Client{}, errors.Wrap(svcerr.ErrUniqueID, err)
}
c.Credentials.Secret = key
}
if c.Status != mgclients.DisabledStatus && c.Status != mgclients.EnabledStatus {
return []mgclients.Client{}, apiutil.ErrInvalidStatus
return []mgclients.Client{}, svcerr.ErrInvalidStatus
}
c.CreatedAt = time.Now()
clients = append(clients, c)
@@ -88,7 +94,7 @@ func (svc service) CreateThings(ctx context.Context, token string, cls ...mgclie
saved, err := svc.clients.Save(ctx, clients...)
if err != nil {
return nil, err
return nil, errors.Wrap(repoerr.ErrCreateEntity, err)
}
policies := magistrala.AddPoliciesReq{}
@@ -112,7 +118,7 @@ func (svc service) CreateThings(ctx context.Context, token string, cls ...mgclie
})
}
if _, err := svc.auth.AddPolicies(ctx, &policies); err != nil {
return nil, err
return nil, errors.Wrap(errAddPolicies, err)
}
return saved, nil
@@ -121,7 +127,7 @@ func (svc service) CreateThings(ctx context.Context, token string, cls ...mgclie
func (svc service) ViewClient(ctx context.Context, token string, id string) (mgclients.Client, error) {
_, err := svc.authorize(ctx, auth.UserType, auth.TokenKind, token, auth.ViewPermission, auth.ThingType, id)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
return svc.clients.RetrieveByID(ctx, id)
@@ -132,7 +138,7 @@ func (svc service) ListClients(ctx context.Context, token string, reqUserID stri
res, err := svc.identify(ctx, token)
if err != nil {
return mgclients.ClientsPage{}, err
return mgclients.ClientsPage{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
switch {
@@ -143,16 +149,16 @@ func (svc service) ListClients(ctx context.Context, token string, reqUserID stri
}
rtids, err := svc.listClientIDs(ctx, auth.EncodeDomainUserID(res.GetDomainId(), reqUserID), pm.Permission)
if err != nil {
return mgclients.ClientsPage{}, err
return mgclients.ClientsPage{}, errors.Wrap(repoerr.ErrNotFound, err)
}
ids, err = svc.filterAllowedThingIDs(ctx, res.GetId(), pm.Permission, rtids)
if err != nil {
return mgclients.ClientsPage{}, err
return mgclients.ClientsPage{}, errors.Wrap(repoerr.ErrNotFound, err)
}
default:
ids, err = svc.listClientIDs(ctx, res.GetId(), pm.Permission)
if err != nil {
return mgclients.ClientsPage{}, err
return mgclients.ClientsPage{}, errors.Wrap(repoerr.ErrNotFound, err)
}
}
@@ -175,7 +181,7 @@ func (svc service) listClientIDs(ctx context.Context, userID, permission string)
ObjectType: auth.ThingType,
})
if err != nil {
return nil, err
return nil, errors.Wrap(repoerr.ErrNotFound, err)
}
return tids.Policies, nil
}
@@ -189,7 +195,7 @@ func (svc service) filterAllowedThingIDs(ctx context.Context, userID, permission
ObjectType: auth.ThingType,
})
if err != nil {
return nil, err
return nil, errors.Wrap(repoerr.ErrNotFound, err)
}
for _, thingID := range thingIDs {
for _, tid := range tids.Policies {
@@ -204,7 +210,7 @@ func (svc service) filterAllowedThingIDs(ctx context.Context, userID, permission
func (svc service) UpdateClient(ctx context.Context, token string, cli mgclients.Client) (mgclients.Client, error) {
userID, err := svc.authorize(ctx, auth.UserType, auth.TokenKind, token, auth.EditPermission, auth.ThingType, cli.ID)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
client := mgclients.Client{
@@ -220,7 +226,7 @@ func (svc service) UpdateClient(ctx context.Context, token string, cli mgclients
func (svc service) UpdateClientTags(ctx context.Context, token string, cli mgclients.Client) (mgclients.Client, error) {
userID, err := svc.authorize(ctx, auth.UserType, auth.TokenKind, token, auth.EditPermission, auth.ThingType, cli.ID)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
client := mgclients.Client{
@@ -235,7 +241,7 @@ func (svc service) UpdateClientTags(ctx context.Context, token string, cli mgcli
func (svc service) UpdateClientSecret(ctx context.Context, token, id, key string) (mgclients.Client, error) {
userID, err := svc.authorize(ctx, auth.UserType, auth.TokenKind, token, auth.EditPermission, auth.ThingType, id)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
client := mgclients.Client{
@@ -253,7 +259,7 @@ func (svc service) UpdateClientSecret(ctx context.Context, token, id, key string
func (svc service) UpdateClientOwner(ctx context.Context, token string, cli mgclients.Client) (mgclients.Client, error) {
userID, err := svc.authorize(ctx, auth.UserType, auth.TokenKind, token, auth.EditPermission, auth.ThingType, cli.ID)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
client := mgclients.Client{
@@ -292,7 +298,7 @@ func (svc service) DisableClient(ctx context.Context, token, id string) (mgclien
}
if err := svc.clientCache.Remove(ctx, client.ID); err != nil {
return client, err
return client, errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return client, nil
@@ -304,7 +310,7 @@ func (svc service) Share(ctx context.Context, token, id, relation string, userid
return nil
}
if _, err := svc.authorize(ctx, auth.UserType, auth.UsersKind, user.GetId(), auth.DeletePermission, auth.ThingType, id); err != nil {
return err
return errors.Wrap(svcerr.ErrAuthorization, err)
}
policies := magistrala.AddPoliciesReq{}
@@ -319,10 +325,10 @@ func (svc service) Share(ctx context.Context, token, id, relation string, userid
}
res, err := svc.auth.AddPolicies(ctx, &policies)
if err != nil {
return err
return errors.Wrap(errAddPolicies, err)
}
if !res.Authorized {
return errors.ErrAuthorization
return err
}
return nil
}
@@ -333,7 +339,7 @@ func (svc service) Unshare(ctx context.Context, token, id, relation string, user
return nil
}
if _, err := svc.authorize(ctx, auth.UserType, auth.UsersKind, user.GetId(), auth.DeletePermission, auth.ThingType, id); err != nil {
return err
return errors.Wrap(svcerr.ErrAuthorization, err)
}
policies := magistrala.DeletePoliciesReq{}
@@ -348,10 +354,10 @@ func (svc service) Unshare(ctx context.Context, token, id, relation string, user
}
res, err := svc.auth.DeletePolicies(ctx, &policies)
if err != nil {
return err
return errors.Wrap(errRemovePolicies, err)
}
if !res.Deleted {
return errors.ErrAuthorization
return err
}
return nil
}
@@ -359,11 +365,11 @@ func (svc service) Unshare(ctx context.Context, token, id, relation string, user
func (svc service) changeClientStatus(ctx context.Context, token string, client mgclients.Client) (mgclients.Client, error) {
userID, err := svc.authorize(ctx, auth.UserType, auth.TokenKind, token, auth.DeletePermission, auth.ThingType, client.ID)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
dbClient, err := svc.clients.RetrieveByID(ctx, client.ID)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
if dbClient.Status == client.Status {
return mgclients.Client{}, mgclients.ErrStatusAlreadyAssigned
@@ -385,14 +391,14 @@ func (svc service) ListClientsByGroup(ctx context.Context, token, groupID string
ObjectType: auth.ThingType,
})
if err != nil {
return mgclients.MembersPage{}, err
return mgclients.MembersPage{}, errors.Wrap(repoerr.ErrNotFound, err)
}
pm.IDs = tids.Policies
cp, err := svc.clients.RetrieveAllByIDs(ctx, pm)
if err != nil {
return mgclients.MembersPage{}, err
return mgclients.MembersPage{}, errors.Wrap(repoerr.ErrNotFound, err)
}
return mgclients.MembersPage{
@@ -409,10 +415,10 @@ func (svc service) Identify(ctx context.Context, key string) (string, error) {
client, err := svc.clients.RetrieveBySecret(ctx, key)
if err != nil {
return "", err
return "", errors.Wrap(repoerr.ErrNotFound, err)
}
if err := svc.clientCache.Save(ctx, key, client.ID); err != nil {
return "", err
return "", errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return client.ID, nil
@@ -421,7 +427,7 @@ func (svc service) Identify(ctx context.Context, key string) (string, error) {
func (svc service) identify(ctx context.Context, token string) (*magistrala.IdentityRes, error) {
res, err := svc.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return nil, err
return nil, errors.Wrap(errors.ErrAuthentication, err)
}
if res.GetId() == "" || res.GetDomainId() == "" {
return nil, errors.ErrDomainAuthorization
@@ -443,7 +449,7 @@ func (svc *service) authorize(ctx context.Context, subjType, subjKind, subj, per
return "", errors.Wrap(errors.ErrAuthorization, err)
}
if !res.GetAuthorized() {
return "", errors.ErrAuthorization
return "", errors.Wrap(errors.ErrAuthorization, err)
}
return res.GetId(), nil
+10 -8
View File
@@ -16,6 +16,8 @@ import (
"github.com/absmach/magistrala/internal/testsutil"
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerror "github.com/absmach/magistrala/pkg/errors/repository"
svcerror "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/uuid"
"github.com/absmach/magistrala/things"
"github.com/absmach/magistrala/things/mocks"
@@ -305,7 +307,7 @@ func TestViewClient(t *testing.T) {
for _, tc := range cases {
repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err)
if tc.token == authmocks.InvalidValue {
repoCall = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization)
repoCall = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, svcerror.ErrAuthorization)
}
repoCall1 := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.response, tc.err)
rClient, err := svc.ViewClient(context.Background(), tc.token, tc.clientID)
@@ -576,8 +578,8 @@ func TestListClients(t *testing.T) {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: getIDs(tc.response.Clients)}, nil)
if tc.token == authmocks.InvalidValue {
repoCall = auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: authmocks.InvalidValue}).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication)
repoCall1 = auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{}, errors.ErrAuthorization)
repoCall = auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: authmocks.InvalidValue}).Return(&magistrala.IdentityRes{}, svcerror.ErrAuthentication)
repoCall1 = auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{}, svcerror.ErrAuthorization)
}
repoCall2 := cRepo.On("RetrieveAllByIDs", context.Background(), mock.Anything).Return(tc.response, tc.err)
page, err := svc.ListClients(context.Background(), tc.token, "", tc.page)
@@ -683,7 +685,7 @@ func TestUpdateClientTags(t *testing.T) {
client: client,
token: "non-existent",
response: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
{
desc: "update client name with invalid ID",
@@ -736,7 +738,7 @@ func TestUpdateClientOwner(t *testing.T) {
client: client,
token: "non-existent",
response: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
{
desc: "update client owner with invalid ID",
@@ -790,7 +792,7 @@ func TestUpdateClientSecret(t *testing.T) {
newSecret: "newPassword",
token: "non-existent",
response: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
}
@@ -845,7 +847,7 @@ func TestEnableClient(t *testing.T) {
token: validToken,
client: mgclients.Client{},
response: mgclients.Client{},
err: errors.ErrNotFound,
err: repoerror.ErrNotFound,
},
}
@@ -966,7 +968,7 @@ func TestDisableClient(t *testing.T) {
client: mgclients.Client{},
token: validToken,
response: mgclients.Client{},
err: errors.ErrNotFound,
err: repoerror.ErrNotFound,
},
}
+3 -3
View File
@@ -7,7 +7,7 @@ import (
"context"
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors"
"google.golang.org/grpc"
)
@@ -40,7 +40,7 @@ func (repo singleUserRepo) Issue(ctx context.Context, in *magistrala.IssueReq, o
func (repo singleUserRepo) Identify(ctx context.Context, in *magistrala.IdentityReq, opts ...grpc.CallOption) (*magistrala.IdentityRes, error) {
if repo.token != in.GetToken() {
return nil, errors.ErrAuthentication
return nil, svcerr.ErrAuthentication
}
return &magistrala.IdentityRes{Id: repo.id}, nil
@@ -48,7 +48,7 @@ func (repo singleUserRepo) Identify(ctx context.Context, in *magistrala.Identity
func (repo singleUserRepo) Authorize(ctx context.Context, in *magistrala.AuthorizeReq, opts ...grpc.CallOption) (*magistrala.AuthorizeRes, error) {
if repo.id != in.Subject {
return &magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization
return &magistrala.AuthorizeRes{Authorized: false}, svcerr.ErrAuthorization
}
return &magistrala.AuthorizeRes{Authorized: true}, nil
+5 -4
View File
@@ -12,6 +12,7 @@ import (
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
repoerror "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/absmach/magistrala/pkg/uuid"
"github.com/absmach/magistrala/twins"
"github.com/absmach/magistrala/twins/mocks"
@@ -75,7 +76,7 @@ func TestTwinsSave(t *testing.T) {
Owner: email,
Name: invalidName,
},
err: errors.ErrMalformedEntity,
err: repoerror.ErrMalformedEntity,
},
}
@@ -123,7 +124,7 @@ func TestTwinsUpdate(t *testing.T) {
twin: twins.Twin{
ID: nonexistentTwinID,
},
err: errors.ErrNotFound,
err: repoerror.ErrNotFound,
},
{
desc: "update twin with invalid name",
@@ -132,7 +133,7 @@ func TestTwinsUpdate(t *testing.T) {
Owner: email,
Name: invalidName,
},
err: errors.ErrMalformedEntity,
err: repoerror.ErrMalformedEntity,
},
}
@@ -176,7 +177,7 @@ func TestTwinsRetrieveByID(t *testing.T) {
{
desc: "retrieve a non-existing twin",
id: nonexistentTwinID,
err: errors.ErrNotFound,
err: repoerror.ErrNotFound,
},
}
+10 -8
View File
@@ -13,6 +13,8 @@ import (
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/mainflux/senml"
)
@@ -104,12 +106,12 @@ func (ts *twinsService) AddTwin(ctx context.Context, token string, twin Twin, de
defer ts.publish(ctx, &id, &err, crudOp["createSucc"], crudOp["createFail"], &b)
res, err := ts.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return Twin{}, err
return Twin{}, errors.Wrap(errors.ErrAuthentication, err)
}
twin.ID, err = ts.idProvider.ID()
if err != nil {
return Twin{}, err
return Twin{}, errors.Wrap(svcerr.ErrUniqueID, err)
}
twin.Owner = res.GetId()
@@ -131,7 +133,7 @@ func (ts *twinsService) AddTwin(ctx context.Context, token string, twin Twin, de
twin.Revision = 0
if _, err = ts.twins.Save(ctx, twin); err != nil {
return Twin{}, err
return Twin{}, errors.Wrap(repoerr.ErrCreateEntity, err)
}
id = twin.ID
@@ -152,7 +154,7 @@ func (ts *twinsService) UpdateTwin(ctx context.Context, token string, twin Twin,
tw, err := ts.twins.RetrieveByID(ctx, twin.ID)
if err != nil {
return err
return errors.Wrap(repoerr.ErrNotFound, err)
}
revision := false
@@ -182,7 +184,7 @@ func (ts *twinsService) UpdateTwin(ctx context.Context, token string, twin Twin,
tw.Revision++
if err := ts.twins.Update(ctx, tw); err != nil {
return err
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
id = twin.ID
@@ -197,12 +199,12 @@ func (ts *twinsService) ViewTwin(ctx context.Context, token, twinID string) (tw
_, err = ts.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return Twin{}, err
return Twin{}, errors.Wrap(errors.ErrAuthorization, err)
}
twin, err := ts.twins.RetrieveByID(ctx, twinID)
if err != nil {
return Twin{}, err
return Twin{}, errors.Wrap(repoerr.ErrNotFound, err)
}
b, err = json.Marshal(twin)
@@ -220,7 +222,7 @@ func (ts *twinsService) RemoveTwin(ctx context.Context, token, twinID string) (e
}
if err := ts.twins.Remove(ctx, twinID); err != nil {
return err
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return ts.twinCache.Remove(ctx, twinID)
+16 -15
View File
@@ -8,6 +8,7 @@ import (
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/absmach/magistrala/users/postgres"
"github.com/stretchr/testify/mock"
)
@@ -24,11 +25,11 @@ func (m *Repository) ChangeStatus(ctx context.Context, client mgclients.Client)
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
if client.Status != mgclients.EnabledStatus && client.Status != mgclients.DisabledStatus {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -37,7 +38,7 @@ func (m *Repository) ChangeStatus(ctx context.Context, client mgclients.Client)
func (m *Repository) Members(ctx context.Context, groupID string, pm mgclients.Page) (mgclients.MembersPage, error) {
ret := m.Called(ctx, groupID, pm)
if groupID == WrongID {
return mgclients.MembersPage{}, errors.ErrNotFound
return mgclients.MembersPage{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.MembersPage), ret.Error(1)
@@ -59,7 +60,7 @@ func (m *Repository) RetrieveByID(ctx context.Context, id string) (mgclients.Cli
ret := m.Called(ctx, id)
if id == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -69,7 +70,7 @@ func (m *Repository) RetrieveByIdentity(ctx context.Context, identity string) (m
ret := m.Called(ctx, identity)
if identity == "" {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -78,10 +79,10 @@ func (m *Repository) RetrieveByIdentity(ctx context.Context, identity string) (m
func (m *Repository) Save(ctx context.Context, client mgclients.Client) (mgclients.Client, error) {
ret := m.Called(ctx, client)
if client.Owner == WrongID {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
if client.Credentials.Secret == "" {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return client, ret.Error(1)
@@ -91,7 +92,7 @@ func (m *Repository) Update(ctx context.Context, client mgclients.Client) (mgcli
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
}
@@ -100,10 +101,10 @@ func (m *Repository) UpdateIdentity(ctx context.Context, client mgclients.Client
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
if client.Credentials.Identity == "" {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -113,10 +114,10 @@ func (m *Repository) UpdateSecret(ctx context.Context, client mgclients.Client)
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
if client.Credentials.Secret == "" {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -126,7 +127,7 @@ func (m *Repository) UpdateTags(ctx context.Context, client mgclients.Client) (m
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -136,7 +137,7 @@ func (m *Repository) UpdateOwner(ctx context.Context, client mgclients.Client) (
ret := m.Called(ctx, client)
if client.ID == WrongID {
return mgclients.Client{}, errors.ErrNotFound
return mgclients.Client{}, repoerr.ErrNotFound
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
@@ -146,7 +147,7 @@ func (m *Repository) RetrieveBySecret(ctx context.Context, key string) (mgclient
ret := m.Called(ctx, key)
if key == "" {
return mgclients.Client{}, errors.ErrMalformedEntity
return mgclients.Client{}, repoerr.ErrMalformedEntity
}
return ret.Get(0).(mgclients.Client), ret.Error(1)
+5 -4
View File
@@ -11,6 +11,7 @@ import (
mgclients "github.com/absmach/magistrala/pkg/clients"
pgclients "github.com/absmach/magistrala/pkg/clients/postgres"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
)
var _ mgclients.Repository = (*clientRepo)(nil)
@@ -43,24 +44,24 @@ func (repo clientRepo) Save(ctx context.Context, c mgclients.Client) (mgclients.
RETURNING id, name, tags, identity, metadata, COALESCE(owner_id, '') AS owner_id, status, created_at`
dbc, err := pgclients.ToDBClient(c)
if err != nil {
return mgclients.Client{}, errors.Wrap(errors.ErrCreateEntity, err)
return mgclients.Client{}, errors.Wrap(repoerr.ErrCreateEntity, err)
}
row, err := repo.ClientRepository.DB.NamedQueryContext(ctx, q, dbc)
if err != nil {
return mgclients.Client{}, postgres.HandleError(err, errors.ErrCreateEntity)
return mgclients.Client{}, postgres.HandleError(repoerr.ErrCreateEntity, err)
}
defer row.Close()
row.Next()
dbc = pgclients.DBClient{}
if err := row.StructScan(&dbc); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
client, err := pgclients.ToClient(dbc)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
return client, nil
+65 -49
View File
@@ -11,9 +11,10 @@ import (
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/auth"
"github.com/absmach/magistrala/internal/apiutil"
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/users/postgres"
)
@@ -23,6 +24,21 @@ var (
// ErrPasswordFormat indicates weak password.
ErrPasswordFormat = errors.New("password does not meet the requirements")
// ErrFailedPolicyUpdate indicates a failure to update user policy.
ErrFailedPolicyUpdate = errors.New("failed to update user policy")
// ErrFailedOwnerUpdate indicates a failure to update user policy.
ErrFailedOwnerUpdate = errors.New("failed to update user owner")
// ErrAddPolicies indictaed a failre to add policies.
errAddPolicies = errors.New("failed to add policies")
// ErrIssueToken indicates a failure to issue token.
ErrIssueToken = errors.New("failed to issue token")
// ErrAddPolicies indictaed a failre to add policies.
errDeletePolicies = errors.New("failed to delete policies")
)
type service struct {
@@ -52,31 +68,31 @@ func (svc service) RegisterClient(ctx context.Context, token string, cli mgclien
if !svc.selfRegister {
userID, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.checkSuperAdmin(ctx, userID); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
}
clientID, err := svc.idProvider.ID()
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrUniqueID, err)
}
if cli.Credentials.Secret == "" {
return mgclients.Client{}, apiutil.ErrMissingSecret
return mgclients.Client{}, errors.Wrap(repoerr.ErrMalformedEntity, repoerr.ErrMissingSecret)
}
hash, err := svc.hasher.Hash(cli.Credentials.Secret)
if err != nil {
return mgclients.Client{}, errors.Wrap(errors.ErrMalformedEntity, err)
return mgclients.Client{}, errors.Wrap(repoerr.ErrMalformedEntity, err)
}
cli.Credentials.Secret = hash
if cli.Status != mgclients.DisabledStatus && cli.Status != mgclients.EnabledStatus {
return mgclients.Client{}, apiutil.ErrInvalidStatus
return mgclients.Client{}, svcerr.ErrInvalidStatus
}
if cli.Role != mgclients.UserRole && cli.Role != mgclients.AdminRole {
return mgclients.Client{}, apiutil.ErrInvalidRole
return mgclients.Client{}, svcerr.ErrInvalidRole
}
cli.ID = clientID
cli.CreatedAt = time.Now()
@@ -89,7 +105,7 @@ func (svc service) RegisterClient(ctx context.Context, token string, cli mgclien
ObjectType: auth.PlatformType,
})
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(repoerr.ErrCreateEntity, err)
}
if !res.Authorized {
return mgclients.Client{}, fmt.Errorf("failed to create policy")
@@ -103,7 +119,7 @@ func (svc service) RegisterClient(ctx context.Context, token string, cli mgclien
Object: auth.MagistralaObject,
ObjectType: auth.PlatformType,
}); errRollback != nil {
err = errors.Wrap(err, errors.Wrap(apiutil.ErrRollbackTx, errRollback))
err = errors.Wrap(err, errors.Wrap(repoerr.ErrRollbackTx, errRollback))
}
}
}()
@@ -114,7 +130,7 @@ func (svc service) RegisterClient(ctx context.Context, token string, cli mgclien
func (svc service) IssueToken(ctx context.Context, identity, secret, domainID string) (*magistrala.Token, error) {
dbUser, err := svc.clients.RetrieveByIdentity(ctx, identity)
if err != nil {
return &magistrala.Token{}, err
return &magistrala.Token{}, errors.Wrap(repoerr.ErrNotFound, err)
}
if err := svc.hasher.Compare(secret, dbUser.Credentials.Secret); err != nil {
return &magistrala.Token{}, errors.Wrap(errors.ErrLogin, err)
@@ -138,18 +154,18 @@ func (svc service) RefreshToken(ctx context.Context, refreshToken, domainID stri
func (svc service) ViewClient(ctx context.Context, token string, id string) (mgclients.Client, error) {
tokenUserID, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if tokenUserID != id {
if err := svc.checkSuperAdmin(ctx, tokenUserID); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
}
client, err := svc.clients.RetrieveByID(ctx, id)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
client.Credentials.Secret = ""
@@ -159,11 +175,11 @@ func (svc service) ViewClient(ctx context.Context, token string, id string) (mgc
func (svc service) ViewProfile(ctx context.Context, token string) (mgclients.Client, error) {
id, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
client, err := svc.clients.RetrieveByID(ctx, id)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
client.Credentials.Secret = ""
@@ -173,7 +189,7 @@ func (svc service) ViewProfile(ctx context.Context, token string) (mgclients.Cli
func (svc service) ListClients(ctx context.Context, token string, pm mgclients.Page) (mgclients.ClientsPage, error) {
userID, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.ClientsPage{}, err
return mgclients.ClientsPage{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.checkSuperAdmin(ctx, userID); err == nil {
return svc.clients.RetrieveAll(ctx, pm)
@@ -193,12 +209,12 @@ func (svc service) ListClients(ctx context.Context, token string, pm mgclients.P
func (svc service) UpdateClient(ctx context.Context, token string, cli mgclients.Client) (mgclients.Client, error) {
tokenUserID, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if tokenUserID != cli.ID {
if err := svc.checkSuperAdmin(ctx, tokenUserID); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
}
@@ -216,12 +232,12 @@ func (svc service) UpdateClient(ctx context.Context, token string, cli mgclients
func (svc service) UpdateClientTags(ctx context.Context, token string, cli mgclients.Client) (mgclients.Client, error) {
tokenUserID, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if tokenUserID != cli.ID {
if err := svc.checkSuperAdmin(ctx, tokenUserID); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
}
@@ -238,12 +254,12 @@ func (svc service) UpdateClientTags(ctx context.Context, token string, cli mgcli
func (svc service) UpdateClientIdentity(ctx context.Context, token, clientID, identity string) (mgclients.Client, error) {
tokenUserID, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if tokenUserID != clientID {
if err := svc.checkSuperAdmin(ctx, tokenUserID); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
}
@@ -278,11 +294,11 @@ func (svc service) GenerateResetToken(ctx context.Context, email, host string) e
func (svc service) ResetSecret(ctx context.Context, resetToken, secret string) error {
id, err := svc.Identify(ctx, resetToken)
if err != nil {
return errors.Wrap(errors.ErrAuthentication, err)
return errors.Wrap(svcerr.ErrAuthentication, err)
}
c, err := svc.clients.RetrieveByID(ctx, id)
if err != nil {
return err
return errors.Wrap(repoerr.ErrNotFound, err)
}
if c.Credentials.Identity == "" {
return errors.ErrNotFound
@@ -303,7 +319,7 @@ func (svc service) ResetSecret(ctx context.Context, resetToken, secret string) e
UpdatedBy: id,
}
if _, err := svc.clients.UpdateSecret(ctx, c); err != nil {
return err
return errors.Wrap(svcerr.ErrAuthorization, err)
}
return nil
}
@@ -311,21 +327,21 @@ func (svc service) ResetSecret(ctx context.Context, resetToken, secret string) e
func (svc service) UpdateClientSecret(ctx context.Context, token, oldSecret, newSecret string) (mgclients.Client, error) {
id, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if !svc.passRegex.MatchString(newSecret) {
return mgclients.Client{}, ErrPasswordFormat
}
dbClient, err := svc.clients.RetrieveByID(ctx, id)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
if _, err := svc.IssueToken(ctx, dbClient.Credentials.Identity, oldSecret, ""); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(ErrIssueToken, err)
}
newSecret, err = svc.hasher.Hash(newSecret)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(repoerr.ErrMalformedEntity, err)
}
dbClient.Credentials.Secret = newSecret
dbClient.UpdatedAt = time.Now()
@@ -342,11 +358,11 @@ func (svc service) SendPasswordReset(_ context.Context, host, email, user, token
func (svc service) UpdateClientRole(ctx context.Context, token string, cli mgclients.Client) (mgclients.Client, error) {
tokenUserID, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.checkSuperAdmin(ctx, tokenUserID); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
client := mgclients.Client{
ID: cli.ID,
@@ -356,15 +372,15 @@ func (svc service) UpdateClientRole(ctx context.Context, token string, cli mgcli
}
if err := svc.updateClientPolicy(ctx, cli.ID, cli.Role); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(ErrFailedPolicyUpdate, err)
}
client, err = svc.clients.UpdateOwner(ctx, client)
if err != nil {
// If failed to update role in DB, then revert back to platform admin policy in spicedb
if errRollback := svc.updateClientPolicy(ctx, cli.ID, mgclients.UserRole); errRollback != nil {
return mgclients.Client{}, errors.Wrap(err, errors.Wrap(apiutil.ErrRollbackTx, errRollback))
return mgclients.Client{}, errors.Wrap(err, errors.Wrap(repoerr.ErrRollbackTx, errRollback))
}
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(ErrFailedOwnerUpdate, err)
}
return client, nil
}
@@ -400,14 +416,14 @@ func (svc service) DisableClient(ctx context.Context, token, id string) (mgclien
func (svc service) changeClientStatus(ctx context.Context, token string, client mgclients.Client) (mgclients.Client, error) {
tokenUserID, err := svc.Identify(ctx, token)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
if err := svc.checkSuperAdmin(ctx, tokenUserID); err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(svcerr.ErrAuthentication, err)
}
dbClient, err := svc.clients.RetrieveByID(ctx, client.ID)
if err != nil {
return mgclients.Client{}, err
return mgclients.Client{}, errors.Wrap(repoerr.ErrNotFound, err)
}
if dbClient.Status == client.Status {
return mgclients.Client{}, mgclients.ErrStatusAlreadyAssigned
@@ -434,7 +450,7 @@ func (svc service) ListMembers(ctx context.Context, token, objectKind string, ob
}
if _, err := svc.authorize(ctx, auth.UserType, auth.TokenKind, token, authzPerm, objectType, objectID); err != nil {
return mgclients.MembersPage{}, err
return mgclients.MembersPage{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
duids, err := svc.auth.ListAllSubjects(ctx, &magistrala.ListSubjectsReq{
SubjectType: auth.UserType,
@@ -443,7 +459,7 @@ func (svc service) ListMembers(ctx context.Context, token, objectKind string, ob
ObjectType: objectType,
})
if err != nil {
return mgclients.MembersPage{}, err
return mgclients.MembersPage{}, errors.Wrap(repoerr.ErrNotFound, err)
}
if len(duids.Policies) == 0 {
return mgclients.MembersPage{
@@ -461,7 +477,7 @@ func (svc service) ListMembers(ctx context.Context, token, objectKind string, ob
cp, err := svc.clients.RetrieveAll(ctx, pm)
if err != nil {
return mgclients.MembersPage{}, err
return mgclients.MembersPage{}, errors.Wrap(repoerr.ErrNotFound, err)
}
return mgclients.MembersPage{
@@ -473,7 +489,7 @@ func (svc service) ListMembers(ctx context.Context, token, objectKind string, ob
func (svc *service) checkSuperAdmin(ctx context.Context, adminID string) error {
if _, err := svc.authorize(ctx, auth.UserType, auth.UsersKind, adminID, auth.AdminPermission, auth.PlatformType, auth.MagistralaObject); err != nil {
if err := svc.clients.CheckSuperAdmin(ctx, adminID); err != nil {
return err
return errors.Wrap(svcerr.ErrAuthorization, err)
}
}
@@ -491,11 +507,11 @@ func (svc *service) authorize(ctx context.Context, subjType, subjKind, subj, per
}
res, err := svc.auth.Authorize(ctx, req)
if err != nil {
return "", errors.Wrap(errors.ErrAuthorization, err)
return "", errors.Wrap(svcerr.ErrAuthorization, err)
}
if !res.GetAuthorized() {
return "", errors.ErrAuthorization
return "", errors.Wrap(svcerr.ErrAuthorization, err)
}
return res.GetId(), nil
}
@@ -503,7 +519,7 @@ func (svc *service) authorize(ctx context.Context, subjType, subjKind, subj, per
func (svc service) Identify(ctx context.Context, token string) (string, error) {
user, err := svc.auth.Identify(ctx, &magistrala.IdentityReq{Token: token})
if err != nil {
return "", err
return "", errors.Wrap(svcerr.ErrAuthentication, err)
}
return user.GetUserId(), nil
}
@@ -519,10 +535,10 @@ func (svc service) updateClientPolicy(ctx context.Context, userID string, role m
Object: auth.MagistralaObject,
})
if err != nil {
return err
return errors.Wrap(errAddPolicies, err)
}
if !resp.Authorized {
return errors.ErrAuthorization
return errors.Wrap(svcerr.ErrAuthorization, err)
}
return nil
case mgclients.UserRole:
@@ -536,10 +552,10 @@ func (svc service) updateClientPolicy(ctx context.Context, userID string, role m
Object: auth.MagistralaObject,
})
if err != nil {
return err
return errors.Wrap(errDeletePolicies, err)
}
if !resp.Deleted {
return errors.ErrAuthorization
return errors.Wrap(errDeletePolicies, err)
}
return nil
}
+15 -13
View File
@@ -16,6 +16,8 @@ import (
"github.com/absmach/magistrala/internal/testsutil"
mgclients "github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerror "github.com/absmach/magistrala/pkg/errors/repository"
svcerror "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/uuid"
"github.com/absmach/magistrala/users"
"github.com/absmach/magistrala/users/hasher"
@@ -238,7 +240,7 @@ func TestRegisterClient(t *testing.T) {
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil)
if tc.token == inValidToken {
repoCall = auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: inValidToken}).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication)
repoCall = auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: inValidToken}).Return(&magistrala.IdentityRes{}, svcerror.ErrAuthentication)
}
repoCall1 := auth.On("AddPolicy", mock.Anything, mock.Anything).Return(&magistrala.AddPolicyRes{Authorized: true}, nil)
repoCall2 := auth.On("DeletePolicy", mock.Anything, mock.Anything).Return(&magistrala.DeletePolicyRes{Deleted: true}, nil)
@@ -600,7 +602,7 @@ func TestListClients(t *testing.T) {
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{UserId: validID}, nil)
if tc.token == inValidToken {
repoCall = auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: inValidToken}).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication)
repoCall = auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: inValidToken}).Return(&magistrala.IdentityRes{}, svcerror.ErrAuthentication)
}
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := cRepo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.response, tc.err)
@@ -723,7 +725,7 @@ func TestUpdateClientTags(t *testing.T) {
client: client,
token: inValidToken,
response: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
{
desc: "update client name with invalid ID",
@@ -789,7 +791,7 @@ func TestUpdateClientIdentity(t *testing.T) {
token: validToken,
id: mocks.WrongID,
response: mgclients.Client{},
err: errors.ErrNotFound,
err: repoerror.ErrNotFound,
},
{
desc: "update client identity with invalid token",
@@ -797,7 +799,7 @@ func TestUpdateClientIdentity(t *testing.T) {
token: inValidToken,
id: client2.ID,
response: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
}
@@ -849,7 +851,7 @@ func TestUpdateClientOwner(t *testing.T) {
client: client,
token: inValidToken,
response: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
{
desc: "update client owner with invalid ID",
@@ -919,7 +921,7 @@ func TestUpdateClientSecret(t *testing.T) {
newSecret: "newPassword",
token: inValidToken,
response: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
{
desc: "update client secret with wrong old secret",
@@ -934,7 +936,7 @@ func TestUpdateClientSecret(t *testing.T) {
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{UserId: client.ID}, nil)
if tc.token == inValidToken {
repoCall = auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: inValidToken}).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication)
repoCall = auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: inValidToken}).Return(&magistrala.IdentityRes{}, svcerror.ErrAuthentication)
}
repoCall1 := cRepo.On("RetrieveByID", context.Background(), client.ID).Return(tc.response, tc.err)
repoCall2 := cRepo.On("RetrieveByIdentity", context.Background(), client.Credentials.Identity).Return(tc.response, tc.err)
@@ -1000,7 +1002,7 @@ func TestEnableClient(t *testing.T) {
token: validToken,
client: mgclients.Client{},
response: mgclients.Client{},
err: errors.ErrNotFound,
err: repoerror.ErrNotFound,
},
}
@@ -1130,7 +1132,7 @@ func TestDisableClient(t *testing.T) {
client: mgclients.Client{},
token: validToken,
response: mgclients.Client{},
err: errors.ErrNotFound,
err: repoerror.ErrNotFound,
},
}
@@ -1369,7 +1371,7 @@ func TestIssueToken(t *testing.T) {
desc: "issue token for a non-existing client",
client: client,
rClient: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
{
desc: "issue token for a client with wrong secret",
@@ -1420,7 +1422,7 @@ func TestRefreshToken(t *testing.T) {
desc: "refresh token with refresh token for a non-existing client",
token: validToken,
client: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
{
desc: "refresh token with access token for an existing client",
@@ -1432,7 +1434,7 @@ func TestRefreshToken(t *testing.T) {
desc: "refresh token with access token for a non-existing client",
token: validToken,
client: mgclients.Client{},
err: errors.ErrAuthentication,
err: svcerror.ErrAuthentication,
},
{
desc: "refresh token with invalid token for an existing client",
+3 -2
View File
@@ -14,6 +14,7 @@ import (
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/mainflux/mproxy/pkg/session"
)
@@ -245,10 +246,10 @@ func (h *handler) authAccess(ctx context.Context, password, topic, action string
}
res, err := h.auth.Authorize(ctx, ar)
if err != nil {
return err
return errors.Wrap(svcerr.ErrAuthorization, err)
}
if !res.GetAuthorized() {
return errors.ErrAuthorization
return errors.Wrap(svcerr.ErrAuthorization, err)
}
return nil