MG-2137 - Generate mocks with mockery for Certs service (#2138)

Signed-off-by: JeffMboya <jangina.mboya@gmail.com>
This commit is contained in:
JMboya
2024-04-23 10:01:01 +03:00
committed by GitHub
parent cf832af14b
commit 3aa674fca0
10 changed files with 828 additions and 622 deletions
@@ -69,6 +69,9 @@ jobs:
- "consumers/notifiers/notifier.go"
- "consumers/notifiers/service.go"
- "consumers/notifiers/subscriptions.go"
- "certs/certs.go"
- "certs/pki/vault.go"
- "certs/service.go"
- name: Set up protoc
if: steps.changes.outputs.proto == 'true'
@@ -144,6 +147,9 @@ jobs:
mv ./consumers/notifiers/mocks/notifier.go ./consumers/notifiers/mocks/notifier.go.tmp
mv ./consumers/notifiers/mocks/service.go ./consumers/notifiers/mocks/service.go.tmp
mv ./consumers/notifiers/mocks/repository.go ./consumers/notifiers/mocks/repository.go.tmp
mv ./certs/mocks/certs.go ./certs/mocks/certs.go.tmp
mv ./certs/mocks/pki.go ./certs/mocks/pki.go.tmp
mv ./certs/mocks/service.go ./certs/mocks/service.go.tmp
make mocks
@@ -188,3 +194,6 @@ jobs:
check_mock_changes ./consumers/notifiers/mocks/notifier.go "Notifiers Notifier ./consumers/notifiers/mocks/notifier.go"
check_mock_changes ./consumers/notifiers/mocks/service.go "Notifiers Service ./consumers/notifiers/mocks/service.go"
check_mock_changes ./consumers/notifiers/mocks/repository.go "Notifiers Repository ./consumers/notifiers/mocks/repository.go"
check_mock_changes ./certs/mocks/certs.go "Certs Repository ./certs/mocks/certs.go"
check_mock_changes ./certs/mocks/pki.go "PKI ./certs/mocks/pki.go"
check_mock_changes ./certs/mocks/service.go "Certs Service ./certs/mocks/service.go"
+2
View File
@@ -24,6 +24,8 @@ type Page struct {
var ErrMissingCerts = errors.New("CA path or CA key path not set")
// Repository specifies a Config persistence API.
//
//go:generate mockery --name Repository --output=./mocks --filename certs.go --quiet --note "Copyright (c) Abstract Machines"
type Repository interface {
// Save saves cert for thing into database
Save(ctx context.Context, cert Cert) (string, error)
+120 -95
View File
@@ -1,137 +1,162 @@
// Code generated by mockery v2.42.1. DO NOT EDIT.
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package mocks
import (
"context"
"sync"
context "context"
"github.com/absmach/magistrala/certs"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
certs "github.com/absmach/magistrala/certs"
mock "github.com/stretchr/testify/mock"
)
var _ certs.Repository = (*certsRepoMock)(nil)
type certsRepoMock struct {
mu sync.Mutex
counter uint64
certsBySerial map[string]certs.Cert
certsByThingID map[string]map[string][]certs.Cert
// Repository is an autogenerated mock type for the Repository type
type Repository struct {
mock.Mock
}
// NewCertsRepository creates in-memory certs repository.
func NewCertsRepository() certs.Repository {
return &certsRepoMock{
certsBySerial: make(map[string]certs.Cert),
certsByThingID: make(map[string]map[string][]certs.Cert),
// Remove provides a mock function with given fields: ctx, ownerID, thingID
func (_m *Repository) Remove(ctx context.Context, ownerID string, thingID string) error {
ret := _m.Called(ctx, ownerID, thingID)
if len(ret) == 0 {
panic("no return value specified for Remove")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, ownerID, thingID)
} else {
r0 = ret.Error(0)
}
return r0
}
func (c *certsRepoMock) Save(ctx context.Context, cert certs.Cert) (string, error) {
c.mu.Lock()
defer c.mu.Unlock()
// RetrieveAll provides a mock function with given fields: ctx, ownerID, offset, limit
func (_m *Repository) RetrieveAll(ctx context.Context, ownerID string, offset uint64, limit uint64) (certs.Page, error) {
ret := _m.Called(ctx, ownerID, offset, limit)
crt := certs.Cert{
OwnerID: cert.OwnerID,
ThingID: cert.ThingID,
Serial: cert.Serial,
Expire: cert.Expire,
if len(ret) == 0 {
panic("no return value specified for RetrieveAll")
}
_, ok := c.certsByThingID[cert.OwnerID][cert.ThingID]
switch ok {
case false:
c.certsByThingID[cert.OwnerID] = map[string][]certs.Cert{
cert.ThingID: {crt},
}
default:
c.certsByThingID[cert.OwnerID][cert.ThingID] = append(c.certsByThingID[cert.OwnerID][cert.ThingID], crt)
var r0 certs.Page
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) (certs.Page, error)); ok {
return rf(ctx, ownerID, offset, limit)
}
if rf, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) certs.Page); ok {
r0 = rf(ctx, ownerID, offset, limit)
} else {
r0 = ret.Get(0).(certs.Page)
}
c.certsBySerial[cert.Serial] = crt
c.counter++
return cert.Serial, nil
if rf, ok := ret.Get(1).(func(context.Context, string, uint64, uint64) error); ok {
r1 = rf(ctx, ownerID, offset, limit)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
func (c *certsRepoMock) RetrieveAll(ctx context.Context, ownerID string, offset, limit uint64) (certs.Page, error) {
c.mu.Lock()
defer c.mu.Unlock()
if limit <= 0 {
return certs.Page{}, nil
// RetrieveBySerial provides a mock function with given fields: ctx, ownerID, serialID
func (_m *Repository) RetrieveBySerial(ctx context.Context, ownerID string, serialID string) (certs.Cert, error) {
ret := _m.Called(ctx, ownerID, serialID)
if len(ret) == 0 {
panic("no return value specified for RetrieveBySerial")
}
oc, ok := c.certsByThingID[ownerID]
if !ok {
return certs.Page{}, repoerr.ErrNotFound
var r0 certs.Cert
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (certs.Cert, error)); ok {
return rf(ctx, ownerID, serialID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) certs.Cert); ok {
r0 = rf(ctx, ownerID, serialID)
} else {
r0 = ret.Get(0).(certs.Cert)
}
var crts []certs.Cert
for _, tc := range oc {
for i, v := range tc {
if uint64(i) >= offset && uint64(i) < offset+limit {
crts = append(crts, v)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, ownerID, serialID)
} else {
r1 = ret.Error(1)
}
page := certs.Page{
Certs: crts,
Total: c.counter,
Offset: offset,
Limit: limit,
}
return page, nil
return r0, r1
}
func (c *certsRepoMock) Remove(ctx context.Context, ownerID, serial string) error {
c.mu.Lock()
defer c.mu.Unlock()
crt, ok := c.certsBySerial[serial]
if !ok {
return repoerr.ErrNotFound
// RetrieveByThing provides a mock function with given fields: ctx, ownerID, thingID, offset, limit
func (_m *Repository) RetrieveByThing(ctx context.Context, ownerID string, thingID string, offset uint64, limit uint64) (certs.Page, error) {
ret := _m.Called(ctx, ownerID, thingID, offset, limit)
if len(ret) == 0 {
panic("no return value specified for RetrieveByThing")
}
delete(c.certsBySerial, crt.Serial)
delete(c.certsByThingID, crt.ThingID)
return nil
var r0 certs.Page
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) (certs.Page, error)); ok {
return rf(ctx, ownerID, thingID, offset, limit)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) certs.Page); ok {
r0 = rf(ctx, ownerID, thingID, offset, limit)
} else {
r0 = ret.Get(0).(certs.Page)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, uint64, uint64) error); ok {
r1 = rf(ctx, ownerID, thingID, offset, limit)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
func (c *certsRepoMock) RetrieveByThing(ctx context.Context, ownerID, thingID string, offset, limit uint64) (certs.Page, error) {
c.mu.Lock()
defer c.mu.Unlock()
if limit <= 0 {
return certs.Page{}, nil
// Save provides a mock function with given fields: ctx, cert
func (_m *Repository) Save(ctx context.Context, cert certs.Cert) (string, error) {
ret := _m.Called(ctx, cert)
if len(ret) == 0 {
panic("no return value specified for Save")
}
cs, ok := c.certsByThingID[ownerID][thingID]
if !ok {
return certs.Page{}, repoerr.ErrNotFound
var r0 string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, certs.Cert) (string, error)); ok {
return rf(ctx, cert)
}
if rf, ok := ret.Get(0).(func(context.Context, certs.Cert) string); ok {
r0 = rf(ctx, cert)
} else {
r0 = ret.Get(0).(string)
}
var crts []certs.Cert
for i, v := range cs {
if uint64(i) >= offset && uint64(i) < offset+limit {
crts = append(crts, v)
}
if rf, ok := ret.Get(1).(func(context.Context, certs.Cert) error); ok {
r1 = rf(ctx, cert)
} else {
r1 = ret.Error(1)
}
page := certs.Page{
Certs: crts,
Total: c.counter,
Offset: offset,
Limit: limit,
}
return page, nil
return r0, r1
}
func (c *certsRepoMock) RetrieveBySerial(ctx context.Context, ownerID, serialID string) (certs.Cert, error) {
c.mu.Lock()
defer c.mu.Unlock()
// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewRepository(t interface {
mock.TestingT
Cleanup(func())
}) *Repository {
mock := &Repository{}
mock.Mock.Test(t)
crt, ok := c.certsBySerial[serialID]
if !ok {
return certs.Cert{}, repoerr.ErrNotFound
}
t.Cleanup(func() { mock.AssertExpectations(t) })
return crt, nil
return mock
}
+106 -167
View File
@@ -1,196 +1,135 @@
// Code generated by mockery v2.42.1. DO NOT EDIT.
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package mocks
import (
"bufio"
"bytes"
"context"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"sync"
"time"
context "context"
"github.com/absmach/magistrala/certs/pki"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
pki "github.com/absmach/magistrala/certs/pki"
mock "github.com/stretchr/testify/mock"
time "time"
)
const keyBits = 2048
var (
errPrivateKeyEmpty = errors.New("private key is empty")
errPrivateKeyUnsupportedType = errors.New("private key type is unsupported")
)
var _ pki.Agent = (*agent)(nil)
type agent struct {
AuthTimeout time.Duration
TLSCert tls.Certificate
X509Cert *x509.Certificate
TTL string
mu sync.Mutex
counter uint64
certs map[string]pki.Cert
// Agent is an autogenerated mock type for the Agent type
type Agent struct {
mock.Mock
}
func NewPkiAgent(tlsCert tls.Certificate, caCert *x509.Certificate, ttl string, timeout time.Duration) pki.Agent {
return &agent{
AuthTimeout: timeout,
TLSCert: tlsCert,
X509Cert: caCert,
TTL: ttl,
certs: make(map[string]pki.Cert),
// IssueCert provides a mock function with given fields: cn, ttl
func (_m *Agent) IssueCert(cn string, ttl string) (pki.Cert, error) {
ret := _m.Called(cn, ttl)
if len(ret) == 0 {
panic("no return value specified for IssueCert")
}
var r0 pki.Cert
var r1 error
if rf, ok := ret.Get(0).(func(string, string) (pki.Cert, error)); ok {
return rf(cn, ttl)
}
if rf, ok := ret.Get(0).(func(string, string) pki.Cert); ok {
r0 = rf(cn, ttl)
} else {
r0 = ret.Get(0).(pki.Cert)
}
if rf, ok := ret.Get(1).(func(string, string) error); ok {
r1 = rf(cn, ttl)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
func (a *agent) IssueCert(cn, ttl string) (pki.Cert, error) {
a.mu.Lock()
defer a.mu.Unlock()
// LoginAndRenew provides a mock function with given fields: ctx
func (_m *Agent) LoginAndRenew(ctx context.Context) error {
ret := _m.Called(ctx)
if a.X509Cert == nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, pki.ErrMissingCACertificate)
if len(ret) == 0 {
panic("no return value specified for LoginAndRenew")
}
var priv interface{}
priv, err := rsa.GenerateKey(rand.Reader, keyBits)
if err != nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
r0 = rf(ctx)
} else {
r0 = ret.Error(0)
}
if ttl == "" {
ttl = a.TTL
}
notBefore := time.Now()
validFor, err := time.ParseDuration(ttl)
if err != nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
}
notAfter := notBefore.Add(validFor)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
}
tmpl := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Magistrala"},
CommonName: cn,
OrganizationalUnit: []string{"magistrala"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
SubjectKeyId: []byte{1, 2, 3, 4, 6},
}
pubKey, err := publicKey(priv)
if err != nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
}
derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, a.X509Cert, pubKey, a.TLSCert.PrivateKey)
if err != nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
}
x509cert, err := x509.ParseCertificate(derBytes)
if err != nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
}
var bw, keyOut bytes.Buffer
buffWriter := bufio.NewWriter(&bw)
buffKeyOut := bufio.NewWriter(&keyOut)
if err := pem.Encode(buffWriter, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
}
buffWriter.Flush()
cert := bw.String()
block, err := pemBlockForKey(priv)
if err != nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
}
if err := pem.Encode(buffKeyOut, block); err != nil {
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
}
buffKeyOut.Flush()
key := keyOut.String()
a.certs[x509cert.SerialNumber.String()] = pki.Cert{
ClientCert: cert,
}
a.counter++
return pki.Cert{
ClientCert: cert,
ClientKey: key,
Serial: x509cert.SerialNumber.String(),
Expire: x509cert.NotAfter.Unix(),
IssuingCA: x509cert.Issuer.String(),
}, nil
return r0
}
func (a *agent) Read(serial string) (pki.Cert, error) {
a.mu.Lock()
defer a.mu.Unlock()
// Read provides a mock function with given fields: serial
func (_m *Agent) Read(serial string) (pki.Cert, error) {
ret := _m.Called(serial)
crt, ok := a.certs[serial]
if !ok {
return pki.Cert{}, repoerr.ErrNotFound
if len(ret) == 0 {
panic("no return value specified for Read")
}
return crt, nil
}
func (a *agent) Revoke(serial string) (time.Time, error) {
return time.Now(), nil
}
func (a *agent) LoginAndRenew(ctx context.Context) error {
return nil
}
func publicKey(priv interface{}) (interface{}, error) {
if priv == nil {
return nil, errPrivateKeyEmpty
var r0 pki.Cert
var r1 error
if rf, ok := ret.Get(0).(func(string) (pki.Cert, error)); ok {
return rf(serial)
}
switch k := priv.(type) {
case *rsa.PrivateKey:
return &k.PublicKey, nil
case *ecdsa.PrivateKey:
return &k.PublicKey, nil
default:
return nil, errPrivateKeyUnsupportedType
if rf, ok := ret.Get(0).(func(string) pki.Cert); ok {
r0 = rf(serial)
} else {
r0 = ret.Get(0).(pki.Cert)
}
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(serial)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
func pemBlockForKey(priv interface{}) (*pem.Block, error) {
switch k := priv.(type) {
case *rsa.PrivateKey:
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}, nil
case *ecdsa.PrivateKey:
b, err := x509.MarshalECPrivateKey(k)
if err != nil {
return nil, err
}
return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}, nil
default:
return nil, nil
// Revoke provides a mock function with given fields: serial
func (_m *Agent) Revoke(serial string) (time.Time, error) {
ret := _m.Called(serial)
if len(ret) == 0 {
panic("no return value specified for Revoke")
}
var r0 time.Time
var r1 error
if rf, ok := ret.Get(0).(func(string) (time.Time, error)); ok {
return rf(serial)
}
if rf, ok := ret.Get(0).(func(string) time.Time); ok {
r0 = rf(serial)
} else {
r0 = ret.Get(0).(time.Time)
}
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(serial)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewAgent creates a new instance of Agent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewAgent(t interface {
mock.TestingT
Cleanup(func())
}) *Agent {
mock := &Agent{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+172
View File
@@ -0,0 +1,172 @@
// Code generated by mockery v2.42.1. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
certs "github.com/absmach/magistrala/certs"
mock "github.com/stretchr/testify/mock"
)
// Service is an autogenerated mock type for the Service type
type Service struct {
mock.Mock
}
// IssueCert provides a mock function with given fields: ctx, token, thingID, ttl
func (_m *Service) IssueCert(ctx context.Context, token string, thingID string, ttl string) (certs.Cert, error) {
ret := _m.Called(ctx, token, thingID, ttl)
if len(ret) == 0 {
panic("no return value specified for IssueCert")
}
var r0 certs.Cert
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (certs.Cert, error)); ok {
return rf(ctx, token, thingID, ttl)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) certs.Cert); ok {
r0 = rf(ctx, token, thingID, ttl)
} else {
r0 = ret.Get(0).(certs.Cert)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = rf(ctx, token, thingID, ttl)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListCerts provides a mock function with given fields: ctx, token, thingID, offset, limit
func (_m *Service) ListCerts(ctx context.Context, token string, thingID string, offset uint64, limit uint64) (certs.Page, error) {
ret := _m.Called(ctx, token, thingID, offset, limit)
if len(ret) == 0 {
panic("no return value specified for ListCerts")
}
var r0 certs.Page
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) (certs.Page, error)); ok {
return rf(ctx, token, thingID, offset, limit)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) certs.Page); ok {
r0 = rf(ctx, token, thingID, offset, limit)
} else {
r0 = ret.Get(0).(certs.Page)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, uint64, uint64) error); ok {
r1 = rf(ctx, token, thingID, offset, limit)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListSerials provides a mock function with given fields: ctx, token, thingID, offset, limit
func (_m *Service) ListSerials(ctx context.Context, token string, thingID string, offset uint64, limit uint64) (certs.Page, error) {
ret := _m.Called(ctx, token, thingID, offset, limit)
if len(ret) == 0 {
panic("no return value specified for ListSerials")
}
var r0 certs.Page
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) (certs.Page, error)); ok {
return rf(ctx, token, thingID, offset, limit)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) certs.Page); ok {
r0 = rf(ctx, token, thingID, offset, limit)
} else {
r0 = ret.Get(0).(certs.Page)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, uint64, uint64) error); ok {
r1 = rf(ctx, token, thingID, offset, limit)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RevokeCert provides a mock function with given fields: ctx, token, serialID
func (_m *Service) RevokeCert(ctx context.Context, token string, serialID string) (certs.Revoke, error) {
ret := _m.Called(ctx, token, serialID)
if len(ret) == 0 {
panic("no return value specified for RevokeCert")
}
var r0 certs.Revoke
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (certs.Revoke, error)); ok {
return rf(ctx, token, serialID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) certs.Revoke); ok {
r0 = rf(ctx, token, serialID)
} else {
r0 = ret.Get(0).(certs.Revoke)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, token, serialID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ViewCert provides a mock function with given fields: ctx, token, serialID
func (_m *Service) ViewCert(ctx context.Context, token string, serialID string) (certs.Cert, error) {
ret := _m.Called(ctx, token, serialID)
if len(ret) == 0 {
panic("no return value specified for ViewCert")
}
var r0 certs.Cert
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (certs.Cert, error)); ok {
return rf(ctx, token, serialID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) certs.Cert); ok {
r0 = rf(ctx, token, serialID)
} else {
r0 = ret.Get(0).(certs.Cert)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, token, serialID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
mock.TestingT
Cleanup(func())
}) *Service {
mock := &Service{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+2
View File
@@ -54,6 +54,8 @@ type Cert struct {
}
// Agent represents the Vault PKI interface.
//
//go:generate mockery --name Agent --output=../mocks --filename pki.go --quiet --note "Copyright (c) Abstract Machines"
type Agent interface {
// IssueCert issues certificate on PKI
IssueCert(cn, ttl string) (Cert, error)
+2
View File
@@ -31,6 +31,8 @@ var _ Service = (*certsService)(nil)
// Service specifies an API that must be fulfilled by the domain service
// implementation, and all of its decorators (e.g. logging & metrics).
//
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines"
type Service interface {
// IssueCert issues certificate for given thing id if access is granted with token
IssueCert(ctx context.Context, token, thingID, ttl string) (Cert, error)
+255 -175
View File
@@ -14,7 +14,9 @@ import (
authmocks "github.com/absmach/magistrala/auth/mocks"
"github.com/absmach/magistrala/certs"
"github.com/absmach/magistrala/certs/mocks"
"github.com/absmach/magistrala/certs/pki"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
mgsdk "github.com/absmach/magistrala/pkg/sdk/go"
sdkmocks "github.com/absmach/magistrala/pkg/sdk/mocks"
@@ -33,217 +35,271 @@ const (
ttl = "1h"
certNum = 10
validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22"
cfgAuthTimeout = "1s"
caPath = "../docker/ssl/certs/ca.crt"
caKeyPath = "../docker/ssl/certs/ca.key"
cfgSignHoursValid = "24h"
instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
)
func newService(t *testing.T) (certs.Service, *authmocks.AuthClient, *sdkmocks.SDK) {
func newService(_ *testing.T) (certs.Service, *mocks.Repository, *mocks.Agent, *authmocks.AuthClient, *sdkmocks.SDK) {
repo := new(mocks.Repository)
agent := new(mocks.Agent)
auth := new(authmocks.AuthClient)
sdk := new(sdkmocks.SDK)
repo := mocks.NewCertsRepository()
tlsCert, caCert, err := certs.LoadCertificates(caPath, caKeyPath)
require.Nil(t, err, fmt.Sprintf("unexpected cert loading error: %s\n", err))
return certs.New(auth, repo, sdk, agent), repo, agent, auth, sdk
}
authTimeout, err := time.ParseDuration(cfgAuthTimeout)
require.Nil(t, err, fmt.Sprintf("unexpected auth timeout parsing error: %s\n", err))
pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignHoursValid, authTimeout)
return certs.New(auth, repo, sdk, pki), auth, sdk
var cert = certs.Cert{
OwnerID: validID,
ThingID: thingID,
Serial: "",
Expire: time.Time{},
}
func TestIssueCert(t *testing.T) {
svc, auth, sdk := newService(t)
svc, repo, agent, auth, sdk := newService(t)
cases := []struct {
token string
desc string
thingID string
ttl string
key string
err error
token string
desc string
thingID string
ttl string
key string
pki pki.Cert
identifyRes *magistrala.IdentityRes
identifyErr error
thingErr errors.SDKError
issueCertErr error
repoErr error
err error
}{
{
desc: "issue new cert",
token: token,
thingID: thingID,
ttl: ttl,
err: nil,
pki: pki.Cert{
ClientCert: "",
IssuingCA: "",
CAChain: []string{},
ClientKey: "",
PrivateKeyType: "",
Serial: "",
Expire: 0,
},
identifyRes: &magistrala.IdentityRes{Id: validID},
},
{
desc: "issue new cert for non existing thing id",
token: token,
thingID: "2",
ttl: ttl,
err: certs.ErrFailedCertCreation,
pki: pki.Cert{
ClientCert: "",
IssuingCA: "",
CAChain: []string{},
ClientKey: "",
PrivateKeyType: "",
Serial: "",
Expire: 0,
},
identifyRes: &magistrala.IdentityRes{Id: validID},
thingErr: errors.NewSDKError(errors.ErrMalformedEntity),
err: certs.ErrFailedCertCreation,
},
{
desc: "issue new cert for non existing thing id",
desc: "issue new cert for invalid token",
token: invalid,
thingID: thingID,
ttl: ttl,
err: svcerr.ErrAuthentication,
pki: pki.Cert{
ClientCert: "",
IssuingCA: "",
CAChain: []string{},
ClientKey: "",
PrivateKeyType: "",
Serial: "",
Expire: 0,
},
identifyRes: &magistrala.IdentityRes{Id: validID},
identifyErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err)
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: tc.thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, errors.NewSDKError(tc.err))
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
sdkCall := sdk.On("Thing", tc.thingID, tc.token).Return(mgsdk.Thing{ID: tc.thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, tc.thingErr)
agentCall := agent.On("IssueCert", thingKey, tc.ttl).Return(tc.pki, tc.issueCertErr)
repoCall := repo.On("Save", context.Background(), mock.Anything).Return("", tc.repoErr)
c, err := svc.IssueCert(context.Background(), tc.token, tc.thingID, tc.ttl)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
cert, _ := certs.ReadCert([]byte(c.ClientCert))
if cert != nil {
assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, thingKey, cert.Subject.CommonName))
}
authCall.Unset()
sdkCall.Unset()
agentCall.Unset()
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
}
func TestRevokeCert(t *testing.T) {
svc, auth, sdk := newService(t)
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
_, err := svc.IssueCert(context.Background(), token, thingID, ttl)
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
svc, repo, _, auth, sdk := newService(t)
cases := []struct {
token string
desc string
thingID string
err error
token string
desc string
thingID string
page certs.Page
identifyRes *magistrala.IdentityRes
identifyErr error
authErr error
thingErr errors.SDKError
repoErr error
err error
}{
{
desc: "revoke cert",
token: token,
thingID: thingID,
err: nil,
desc: "revoke cert",
token: token,
thingID: thingID,
page: certs.Page{Limit: 10000, Offset: 0, Total: 1, Certs: []certs.Cert{cert}},
identifyRes: &magistrala.IdentityRes{Id: validID},
},
{
desc: "revoke cert for invalid token",
token: invalid,
thingID: thingID,
err: svcerr.ErrAuthentication,
desc: "revoke cert for invalid token",
token: invalid,
thingID: thingID,
page: certs.Page{},
identifyRes: &magistrala.IdentityRes{Id: validID},
identifyErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "revoke cert for invalid thing id",
token: token,
thingID: "2",
err: certs.ErrFailedCertRevocation,
desc: "revoke cert for invalid thing id",
token: token,
thingID: "2",
page: certs.Page{},
identifyRes: &magistrala.IdentityRes{Id: validID},
thingErr: errors.NewSDKError(certs.ErrFailedCertCreation),
err: certs.ErrFailedCertRevocation,
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err)
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: tc.thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, errors.NewSDKError(tc.err))
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
authCall1 := auth.On("Authorize", context.Background(), mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.authErr)
sdkCall := sdk.On("Thing", tc.thingID, tc.token).Return(mgsdk.Thing{ID: tc.thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, tc.thingErr)
repoCall := repo.On("RetrieveByThing", context.Background(), validID, tc.thingID, tc.page.Offset, tc.page.Limit).Return(certs.Page{}, tc.repoErr)
_, err := svc.RevokeCert(context.Background(), tc.token, tc.thingID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
authCall.Unset()
authCall1.Unset()
sdkCall.Unset()
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
}
func TestListCerts(t *testing.T) {
svc, auth, sdk := newService(t)
svc, repo, agent, auth, _ := newService(t)
var mycerts []certs.Cert
for i := 0; i < certNum; i++ {
c := certs.Cert{
OwnerID: validID,
ThingID: thingID,
Serial: fmt.Sprintf("%d", i),
Expire: time.Now().Add(time.Hour),
}
mycerts = append(mycerts, c)
}
for i := 0; i < certNum; i++ {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
_, err := svc.IssueCert(context.Background(), token, thingID, ttl)
require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
agent.On("Read", fmt.Sprintf("%d", i)).Return(pki.Cert{}, nil)
}
cases := []struct {
token string
desc string
thingID string
offset uint64
limit uint64
size uint64
err error
token string
desc string
thingID string
page certs.Page
cert certs.Cert
identifyRes *magistrala.IdentityRes
identifyErr error
repoErr error
err error
}{
{
desc: "list all certs with valid token",
token: token,
thingID: thingID,
offset: 0,
limit: certNum,
size: certNum,
err: nil,
page: certs.Page{Limit: certNum, Offset: 0, Total: certNum, Certs: mycerts},
cert: certs.Cert{
OwnerID: validID,
ThingID: thingID,
Serial: "0",
Expire: time.Now().Add(time.Hour),
},
identifyRes: &magistrala.IdentityRes{Id: validID},
},
{
desc: "list all certs with invalid token",
token: invalid,
thingID: thingID,
offset: 0,
limit: certNum,
size: 0,
err: svcerr.ErrAuthentication,
page: certs.Page{},
cert: certs.Cert{
OwnerID: validID,
ThingID: thingID,
Serial: fmt.Sprintf("%d", certNum-1),
Expire: time.Now().Add(time.Hour),
},
identifyRes: &magistrala.IdentityRes{Id: validID},
identifyErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "list half certs with valid token",
token: token,
thingID: thingID,
offset: certNum / 2,
limit: certNum,
size: certNum / 2,
err: nil,
page: certs.Page{Limit: certNum, Offset: certNum / 2, Total: certNum / 2, Certs: mycerts[certNum/2:]},
cert: certs.Cert{
OwnerID: validID,
ThingID: thingID,
Serial: fmt.Sprintf("%d", certNum/2),
Expire: time.Now().Add(time.Hour),
},
identifyRes: &magistrala.IdentityRes{Id: validID},
},
{
desc: "list last cert with valid token",
token: token,
thingID: thingID,
offset: certNum - 1,
limit: certNum,
size: 1,
err: nil,
page: certs.Page{Limit: certNum, Offset: certNum - 1, Total: 1, Certs: []certs.Cert{mycerts[certNum-1]}},
cert: certs.Cert{
OwnerID: validID,
ThingID: thingID,
Serial: fmt.Sprintf("%d", certNum-1),
Expire: time.Now().Add(time.Hour),
},
identifyRes: &magistrala.IdentityRes{Id: validID},
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
page, err := svc.ListCerts(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit)
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
repoCall := repo.On("RetrieveByThing", context.Background(), validID, thingID, tc.page.Offset, tc.page.Limit).Return(tc.page, tc.repoErr)
page, err := svc.ListCerts(context.Background(), tc.token, tc.thingID, tc.page.Offset, tc.page.Limit)
size := uint64(len(page.Certs))
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size))
assert.Equal(t, tc.page.Total, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.page.Total, size))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
authCall.Unset()
repoCall.Unset()
}
}
func TestListSerials(t *testing.T) {
svc, auth, sdk := newService(t)
svc, repo, _, auth, _ := newService(t)
var issuedCerts []certs.Cert
for i := 0; i < certNum; i++ {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
cert, err := svc.IssueCert(context.Background(), token, thingID, ttl)
assert.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
crt := certs.Cert{
OwnerID: cert.OwnerID,
ThingID: cert.ThingID,
@@ -254,72 +310,83 @@ func TestListSerials(t *testing.T) {
}
cases := []struct {
token string
desc string
thingID string
offset uint64
limit uint64
certs []certs.Cert
err error
token string
desc string
thingID string
offset uint64
limit uint64
certs []certs.Cert
identifyRes *magistrala.IdentityRes
identifyErr error
repoErr error
err error
}{
{
desc: "list all certs with valid token",
token: token,
thingID: thingID,
offset: 0,
limit: certNum,
certs: issuedCerts,
err: nil,
desc: "list all certs with valid token",
token: token,
thingID: thingID,
offset: 0,
limit: certNum,
certs: issuedCerts,
identifyRes: &magistrala.IdentityRes{Id: validID},
},
{
desc: "list all certs with invalid token",
token: invalid,
thingID: thingID,
offset: 0,
limit: certNum,
certs: nil,
err: svcerr.ErrAuthentication,
desc: "list all certs with invalid token",
token: invalid,
thingID: thingID,
offset: 0,
limit: certNum,
certs: nil,
identifyRes: &magistrala.IdentityRes{Id: validID},
identifyErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "list half certs with valid token",
token: token,
thingID: thingID,
offset: certNum / 2,
limit: certNum,
certs: issuedCerts[certNum/2:],
err: nil,
desc: "list half certs with valid token",
token: token,
thingID: thingID,
offset: certNum / 2,
limit: certNum,
certs: issuedCerts[certNum/2:],
identifyRes: &magistrala.IdentityRes{Id: validID},
},
{
desc: "list last cert with valid token",
token: token,
thingID: thingID,
offset: certNum - 1,
limit: certNum,
certs: []certs.Cert{issuedCerts[certNum-1]},
err: nil,
desc: "list last cert with valid token",
token: token,
thingID: thingID,
offset: certNum - 1,
limit: certNum,
certs: []certs.Cert{issuedCerts[certNum-1]},
identifyRes: &magistrala.IdentityRes{Id: validID},
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
repoCall := repo.On("RetrieveByThing", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certs.Page{Limit: tc.limit, Offset: tc.offset, Total: certNum, Certs: tc.certs}, tc.repoErr)
page, err := svc.ListSerials(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit)
assert.Equal(t, tc.certs, page.Certs, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.certs, page.Certs))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
authCall.Unset()
repoCall.Unset()
}
}
func TestViewCert(t *testing.T) {
svc, auth, sdk := newService(t)
svc, repo, agent, auth, sdk := newService(t)
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
sdkCall := sdk.On("Thing", thingID, token).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
agentCall := agent.On("IssueCert", thingKey, ttl).Return(pki.Cert{}, nil)
repoCall := repo.On("Save", context.Background(), mock.Anything).Return("", nil)
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
ic, err := svc.IssueCert(context.Background(), token, thingID, ttl)
require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
authCall.Unset()
sdkCall.Unset()
agentCall.Unset()
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
cert := certs.Cert{
ThingID: thingID,
@@ -329,40 +396,53 @@ func TestViewCert(t *testing.T) {
}
cases := []struct {
token string
desc string
serialID string
cert certs.Cert
err error
token string
desc string
serialID string
cert certs.Cert
identifyRes *magistrala.IdentityRes
identifyErr error
repoErr error
agentErr error
err error
}{
{
desc: "list cert with valid token and serial",
token: token,
serialID: cert.Serial,
cert: cert,
err: nil,
desc: "list cert with valid token and serial",
token: token,
serialID: cert.Serial,
cert: cert,
identifyRes: &magistrala.IdentityRes{Id: validID},
},
{
desc: "list cert with invalid token",
token: invalid,
serialID: cert.Serial,
cert: certs.Cert{},
err: svcerr.ErrAuthentication,
desc: "list cert with invalid token",
token: invalid,
serialID: cert.Serial,
cert: certs.Cert{},
identifyRes: &magistrala.IdentityRes{Id: validID},
identifyErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "list cert with invalid serial",
token: token,
serialID: invalid,
cert: certs.Cert{},
err: svcerr.ErrNotFound,
desc: "list cert with invalid serial",
token: token,
serialID: invalid,
cert: certs.Cert{},
identifyRes: &magistrala.IdentityRes{Id: validID},
repoErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
repoCall := repo.On("RetrieveBySerial", context.Background(), validID, tc.serialID).Return(tc.cert, tc.repoErr)
agentCall := agent.On("Read", tc.serialID).Return(pki.Cert{}, tc.agentErr)
cert, err := svc.ViewCert(context.Background(), tc.token, tc.serialID)
assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.cert, cert))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
authCall.Unset()
repoCall.Unset()
agentCall.Unset()
}
}
+159 -182
View File
@@ -10,65 +10,46 @@ import (
"testing"
"time"
"github.com/absmach/magistrala"
authmocks "github.com/absmach/magistrala/auth/mocks"
"github.com/absmach/magistrala/certs"
httpapi "github.com/absmach/magistrala/certs/api"
"github.com/absmach/magistrala/certs/mocks"
"github.com/absmach/magistrala/internal/apiutil"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
sdk "github.com/absmach/magistrala/pkg/sdk/go"
thmocks "github.com/absmach/magistrala/things/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
const instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
var (
thingID = "1"
caPath = "../../../docker/ssl/certs/ca.crt"
caKeyPath = "../../../docker/ssl/certs/ca.key"
cfgAuthTimeout = "1s"
cfgSignHoursValid = "24h"
)
var thingID = "1"
func setupCerts() (*httptest.Server, *authmocks.AuthClient, *thmocks.Repository, error) {
server, trepo, _, auth, _ := setupThings()
config := sdk.Config{
ThingsURL: server.URL,
}
var c = certs.Cert{
OwnerID: "",
ThingID: thingID,
ClientCert: "",
IssuingCA: "",
CAChain: []string{},
ClientKey: "",
PrivateKeyType: "",
Serial: "",
Expire: time.Time{},
}
mgsdk := sdk.NewSDK(config)
repo := mocks.NewCertsRepository()
tlsCert, caCert, err := certs.LoadCertificates(caPath, caKeyPath)
if err != nil {
return nil, auth, trepo, err
}
authTimeout, err := time.ParseDuration(cfgAuthTimeout)
if err != nil {
return nil, auth, trepo, err
}
pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignHoursValid, authTimeout)
svc := certs.New(auth, repo, mgsdk, pki)
func setupCerts() (*httptest.Server, *mocks.Service) {
svc := new(mocks.Service)
logger := mglog.NewMock()
mux := httpapi.MakeHandler(svc, logger, instanceID)
return httptest.NewServer(mux), auth, trepo, nil
return httptest.NewServer(mux), svc
}
func TestIssueCert(t *testing.T) {
ts, auth, trepo, err := setupCerts()
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
ts, svc := setupCerts()
defer ts.Close()
sdkConf := sdk.Config{
@@ -84,84 +65,93 @@ func TestIssueCert(t *testing.T) {
thingID string
duration string
token string
cRes certs.Cert
err errors.SDKError
svcerr error
}{
{
desc: "create new cert with thing id and duration",
thingID: thingID,
duration: "10h",
token: validToken,
err: nil,
cRes: c,
},
{
desc: "create new cert with empty thing id and duration",
thingID: "",
duration: "10h",
token: validToken,
cRes: c,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
svcerr: errors.Wrap(certs.ErrFailedCertCreation, apiutil.ErrMissingID),
},
{
desc: "create new cert with invalid thing id and duration",
thingID: "ah",
duration: "10h",
token: validToken,
cRes: c,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, certs.ErrFailedCertCreation), http.StatusBadRequest),
svcerr: errors.Wrap(certs.ErrFailedCertCreation, apiutil.ErrValidation),
},
{
desc: "create new cert with thing id and empty duration",
thingID: thingID,
duration: "",
token: exampleUser1,
cRes: c,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingCertData), http.StatusBadRequest),
svcerr: errors.Wrap(certs.ErrFailedCertCreation, apiutil.ErrMissingCertData),
},
{
desc: "create new cert with thing id and malformed duration",
thingID: thingID,
duration: "10g",
token: exampleUser1,
cRes: c,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidCertData), http.StatusBadRequest),
svcerr: errors.Wrap(certs.ErrFailedCertCreation, apiutil.ErrInvalidCertData),
},
{
desc: "create new cert with empty token",
thingID: thingID,
duration: "10h",
token: "",
cRes: c,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
svcerr: errors.Wrap(certs.ErrFailedCertCreation, svcerr.ErrAuthentication),
},
{
desc: "create new cert with invalid token",
thingID: thingID,
duration: "10h",
token: authmocks.InvalidValue,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, svcerr.ErrAuthentication), http.StatusUnauthorized),
cRes: c,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, certs.ErrFailedCertCreation), http.StatusUnauthorized),
svcerr: errors.Wrap(certs.ErrFailedCertCreation, svcerr.ErrAuthentication),
},
{
desc: "create new empty cert",
thingID: "",
duration: "",
token: validToken,
cRes: c,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
svcerr: errors.Wrap(certs.ErrFailedCertCreation, certs.ErrFailedCertCreation),
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID}, tc.err)
cert, err := mgsdk.IssueCert(tc.thingID, tc.duration, tc.token)
svcCall := svc.On("IssueCert", mock.Anything, tc.token, tc.thingID, tc.duration).Return(tc.cRes, tc.svcerr)
_, err := mgsdk.IssueCert(tc.thingID, tc.duration, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
svcCall.Unset()
}
}
func TestViewCert(t *testing.T) {
ts, auth, trepo, err := setupCerts()
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
ts, svc := setupCerts()
defer ts.Close()
sdkConf := sdk.Config{
@@ -172,59 +162,52 @@ func TestViewCert(t *testing.T) {
mgsdk := sdk.NewSDK(sdkConf)
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID}, nil)
cert, err := mgsdk.IssueCert(thingID, "10h", token)
require.Nil(t, err, fmt.Sprintf("unexpected error during creating cert: %s", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
cases := []struct {
desc string
certID string
token string
err errors.SDKError
response sdk.Subscription
desc string
certID string
token string
err errors.SDKError
svcerr error
cRes certs.Cert
}{
{
desc: "get existing cert",
certID: cert.CertSerial,
token: token,
err: nil,
response: sub1,
desc: "get existing cert",
certID: validID,
token: token,
cRes: c,
},
{
desc: "get non-existent cert",
certID: "43",
token: token,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, svcerr.ErrNotFound), http.StatusNotFound),
response: sdk.Subscription{},
desc: "get non-existent cert",
certID: "43",
token: token,
cRes: c,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, svcerr.ErrNotFound), http.StatusNotFound),
svcerr: errors.Wrap(svcerr.ErrNotFound, repoerr.ErrNotFound),
},
{
desc: "get cert with invalid token",
certID: cert.CertSerial,
token: "",
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
response: sdk.Subscription{},
desc: "get cert with invalid token",
certID: validID,
token: "",
cRes: c,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
svcerr: errors.Wrap(svcerr.ErrAuthentication, apiutil.ErrBearerToken),
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
svcCall := svc.On("ViewCert", mock.Anything, tc.token, tc.certID).Return(tc.cRes, tc.svcerr)
cert, err := mgsdk.ViewCert(tc.certID, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc))
}
repoCall.Unset()
svcCall.Unset()
}
}
func TestViewCertByThing(t *testing.T) {
ts, auth, trepo, err := setupCerts()
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
ts, svc := setupCerts()
defer ts.Close()
sdkConf := sdk.Config{
@@ -235,127 +218,121 @@ func TestViewCertByThing(t *testing.T) {
mgsdk := sdk.NewSDK(sdkConf)
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID}, nil)
_, err = mgsdk.IssueCert(thingID, "10h", token)
require.Nil(t, err, fmt.Sprintf("unexpected error during creating cert: %s", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
cases := []struct {
desc string
thingID string
token string
err errors.SDKError
response sdk.Subscription
}{
{
desc: "get existing cert",
thingID: thingID,
token: token,
err: nil,
response: sub1,
},
{
desc: "get non-existent cert",
thingID: "43",
token: token,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, repoerr.ErrNotFound), http.StatusNotFound),
response: sdk.Subscription{},
},
{
desc: "get cert with invalid token",
thingID: thingID,
token: "",
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
response: sdk.Subscription{},
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
cert, err := mgsdk.ViewCertByThing(tc.thingID, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc))
}
repoCall.Unset()
}
}
func TestRevokeCert(t *testing.T) {
ts, auth, trepo, err := setupCerts()
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
defer ts.Close()
sdkConf := sdk.Config{
CertsURL: ts.URL,
MsgContentType: contentType,
TLSVerification: false,
}
mgsdk := sdk.NewSDK(sdkConf)
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID}, nil)
_, err = mgsdk.IssueCert(thingID, "10h", validToken)
require.Nil(t, err, fmt.Sprintf("unexpected error during creating cert: %s", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
cases := []struct {
desc string
thingID string
token string
page certs.Page
err errors.SDKError
viewerr errors.SDKError
svcerr error
}{
{
desc: "revoke cert with invalid token",
desc: "get existing cert",
thingID: thingID,
token: authmocks.InvalidValue,
err: errors.NewSDKErrorWithStatus(errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication), http.StatusUnauthorized),
},
{
desc: "revoke non-existing cert",
thingID: "2",
token: token,
err: errors.NewSDKErrorWithStatus(errors.Wrap(certs.ErrFailedCertRevocation, svcerr.ErrNotFound), http.StatusNotFound),
page: certs.Page{Certs: []certs.Cert{c}},
},
{
desc: "revoke cert with empty token",
desc: "get non-existent cert",
thingID: "43",
token: token,
page: certs.Page{Certs: []certs.Cert{}},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, repoerr.ErrNotFound), http.StatusNotFound),
svcerr: errors.Wrap(svcerr.ErrNotFound, repoerr.ErrNotFound),
viewerr: errors.NewSDKError(svcerr.ErrViewEntity),
},
{
desc: "get cert with invalid token",
thingID: thingID,
token: "",
page: certs.Page{Certs: []certs.Cert{}},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
},
{
desc: "revoke existing cert",
thingID: thingID,
token: token,
err: nil,
},
{
desc: "revoke deleted cert",
thingID: thingID,
token: token,
err: errors.NewSDKErrorWithStatus(errors.Wrap(certs.ErrFailedToRemoveCertFromDB, svcerr.ErrNotFound), http.StatusNotFound),
svcerr: errors.Wrap(svcerr.ErrAuthentication, apiutil.ErrBearerToken),
},
}
for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID}, nil)
svcCall := svc.On("ListSerials", mock.Anything, tc.token, tc.thingID, tc.page.Offset, mock.Anything).Return(tc.page, tc.svcerr)
svcCall1 := svc.On("ViewCertByThing", mock.Anything, tc.thingID, tc.token).Return(tc.page, tc.viewerr)
cert, err := mgsdk.ViewCertByThing(tc.thingID, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc))
}
svcCall.Unset()
svcCall1.Unset()
}
}
func TestRevokeCert(t *testing.T) {
ts, svc := setupCerts()
defer ts.Close()
sdkConf := sdk.Config{
CertsURL: ts.URL,
MsgContentType: contentType,
TLSVerification: false,
}
mgsdk := sdk.NewSDK(sdkConf)
cases := []struct {
desc string
page certs.Page
thingID string
token string
svcResponse certs.Revoke
err errors.SDKError
svcerr error
}{
{
desc: "revoke cert with invalid token",
thingID: thingID,
token: authmocks.InvalidValue,
svcResponse: certs.Revoke{RevocationTime: time.Now()},
err: errors.NewSDKErrorWithStatus(errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication), http.StatusUnauthorized),
svcerr: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
},
{
desc: "revoke non-existing cert",
thingID: "2",
token: token,
svcResponse: certs.Revoke{RevocationTime: time.Now()},
err: errors.NewSDKErrorWithStatus(errors.Wrap(certs.ErrFailedCertRevocation, svcerr.ErrNotFound), http.StatusNotFound),
svcerr: errors.Wrap(certs.ErrFailedCertRevocation, svcerr.ErrNotFound),
},
{
desc: "revoke cert with empty token",
thingID: thingID,
token: "",
svcResponse: certs.Revoke{RevocationTime: time.Now()},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
svcerr: errors.Wrap(svcerr.ErrAuthentication, apiutil.ErrBearerToken),
},
{
desc: "revoke existing cert",
thingID: thingID,
token: token,
svcResponse: certs.Revoke{RevocationTime: time.Now()},
},
{
desc: "revoke deleted cert",
thingID: thingID,
token: token,
svcResponse: certs.Revoke{RevocationTime: time.Now()},
err: errors.NewSDKErrorWithStatus(errors.Wrap(certs.ErrFailedToRemoveCertFromDB, svcerr.ErrNotFound), http.StatusNotFound),
svcerr: errors.Wrap(certs.ErrFailedToRemoveCertFromDB, svcerr.ErrNotFound),
},
}
for _, tc := range cases {
svcCall := svc.On("RevokeCert", mock.Anything, tc.token, tc.thingID).Return(tc.svcResponse, tc.svcerr)
response, err := mgsdk.RevokeCert(tc.thingID, tc.token)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, response, fmt.Sprintf("%s: got empty revocation time", tc.desc))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
svcCall.Unset()
}
}
+1 -3
View File
@@ -11,7 +11,6 @@ import (
"github.com/absmach/magistrala/pkg/errors"
sdk "github.com/absmach/magistrala/pkg/sdk/go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHealth(t *testing.T) {
@@ -23,8 +22,7 @@ func TestHealth(t *testing.T) {
auth.Test(t)
defer usclsv.Close()
CertTs, _, _, err := setupCerts()
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
CertTs, _ := setupCerts()
defer CertTs.Close()
sdkConf := sdk.Config{