mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
MG-2137 - Generate mocks with mockery for Certs service (#2138)
Signed-off-by: JeffMboya <jangina.mboya@gmail.com>
This commit is contained in:
@@ -69,6 +69,9 @@ jobs:
|
||||
- "consumers/notifiers/notifier.go"
|
||||
- "consumers/notifiers/service.go"
|
||||
- "consumers/notifiers/subscriptions.go"
|
||||
- "certs/certs.go"
|
||||
- "certs/pki/vault.go"
|
||||
- "certs/service.go"
|
||||
|
||||
- name: Set up protoc
|
||||
if: steps.changes.outputs.proto == 'true'
|
||||
@@ -144,6 +147,9 @@ jobs:
|
||||
mv ./consumers/notifiers/mocks/notifier.go ./consumers/notifiers/mocks/notifier.go.tmp
|
||||
mv ./consumers/notifiers/mocks/service.go ./consumers/notifiers/mocks/service.go.tmp
|
||||
mv ./consumers/notifiers/mocks/repository.go ./consumers/notifiers/mocks/repository.go.tmp
|
||||
mv ./certs/mocks/certs.go ./certs/mocks/certs.go.tmp
|
||||
mv ./certs/mocks/pki.go ./certs/mocks/pki.go.tmp
|
||||
mv ./certs/mocks/service.go ./certs/mocks/service.go.tmp
|
||||
|
||||
make mocks
|
||||
|
||||
@@ -188,3 +194,6 @@ jobs:
|
||||
check_mock_changes ./consumers/notifiers/mocks/notifier.go "Notifiers Notifier ./consumers/notifiers/mocks/notifier.go"
|
||||
check_mock_changes ./consumers/notifiers/mocks/service.go "Notifiers Service ./consumers/notifiers/mocks/service.go"
|
||||
check_mock_changes ./consumers/notifiers/mocks/repository.go "Notifiers Repository ./consumers/notifiers/mocks/repository.go"
|
||||
check_mock_changes ./certs/mocks/certs.go "Certs Repository ./certs/mocks/certs.go"
|
||||
check_mock_changes ./certs/mocks/pki.go "PKI ./certs/mocks/pki.go"
|
||||
check_mock_changes ./certs/mocks/service.go "Certs Service ./certs/mocks/service.go"
|
||||
|
||||
@@ -24,6 +24,8 @@ type Page struct {
|
||||
var ErrMissingCerts = errors.New("CA path or CA key path not set")
|
||||
|
||||
// Repository specifies a Config persistence API.
|
||||
//
|
||||
//go:generate mockery --name Repository --output=./mocks --filename certs.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Repository interface {
|
||||
// Save saves cert for thing into database
|
||||
Save(ctx context.Context, cert Cert) (string, error)
|
||||
|
||||
+120
-95
@@ -1,137 +1,162 @@
|
||||
// Code generated by mockery v2.42.1. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
context "context"
|
||||
|
||||
"github.com/absmach/magistrala/certs"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
certs "github.com/absmach/magistrala/certs"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
var _ certs.Repository = (*certsRepoMock)(nil)
|
||||
|
||||
type certsRepoMock struct {
|
||||
mu sync.Mutex
|
||||
counter uint64
|
||||
certsBySerial map[string]certs.Cert
|
||||
certsByThingID map[string]map[string][]certs.Cert
|
||||
// Repository is an autogenerated mock type for the Repository type
|
||||
type Repository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// NewCertsRepository creates in-memory certs repository.
|
||||
func NewCertsRepository() certs.Repository {
|
||||
return &certsRepoMock{
|
||||
certsBySerial: make(map[string]certs.Cert),
|
||||
certsByThingID: make(map[string]map[string][]certs.Cert),
|
||||
// Remove provides a mock function with given fields: ctx, ownerID, thingID
|
||||
func (_m *Repository) Remove(ctx context.Context, ownerID string, thingID string) error {
|
||||
ret := _m.Called(ctx, ownerID, thingID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Remove")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
|
||||
r0 = rf(ctx, ownerID, thingID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
func (c *certsRepoMock) Save(ctx context.Context, cert certs.Cert) (string, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// RetrieveAll provides a mock function with given fields: ctx, ownerID, offset, limit
|
||||
func (_m *Repository) RetrieveAll(ctx context.Context, ownerID string, offset uint64, limit uint64) (certs.Page, error) {
|
||||
ret := _m.Called(ctx, ownerID, offset, limit)
|
||||
|
||||
crt := certs.Cert{
|
||||
OwnerID: cert.OwnerID,
|
||||
ThingID: cert.ThingID,
|
||||
Serial: cert.Serial,
|
||||
Expire: cert.Expire,
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveAll")
|
||||
}
|
||||
|
||||
_, ok := c.certsByThingID[cert.OwnerID][cert.ThingID]
|
||||
switch ok {
|
||||
case false:
|
||||
c.certsByThingID[cert.OwnerID] = map[string][]certs.Cert{
|
||||
cert.ThingID: {crt},
|
||||
}
|
||||
default:
|
||||
c.certsByThingID[cert.OwnerID][cert.ThingID] = append(c.certsByThingID[cert.OwnerID][cert.ThingID], crt)
|
||||
var r0 certs.Page
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) (certs.Page, error)); ok {
|
||||
return rf(ctx, ownerID, offset, limit)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) certs.Page); ok {
|
||||
r0 = rf(ctx, ownerID, offset, limit)
|
||||
} else {
|
||||
r0 = ret.Get(0).(certs.Page)
|
||||
}
|
||||
|
||||
c.certsBySerial[cert.Serial] = crt
|
||||
c.counter++
|
||||
return cert.Serial, nil
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, uint64, uint64) error); ok {
|
||||
r1 = rf(ctx, ownerID, offset, limit)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (c *certsRepoMock) RetrieveAll(ctx context.Context, ownerID string, offset, limit uint64) (certs.Page, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if limit <= 0 {
|
||||
return certs.Page{}, nil
|
||||
// RetrieveBySerial provides a mock function with given fields: ctx, ownerID, serialID
|
||||
func (_m *Repository) RetrieveBySerial(ctx context.Context, ownerID string, serialID string) (certs.Cert, error) {
|
||||
ret := _m.Called(ctx, ownerID, serialID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveBySerial")
|
||||
}
|
||||
|
||||
oc, ok := c.certsByThingID[ownerID]
|
||||
if !ok {
|
||||
return certs.Page{}, repoerr.ErrNotFound
|
||||
var r0 certs.Cert
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) (certs.Cert, error)); ok {
|
||||
return rf(ctx, ownerID, serialID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) certs.Cert); ok {
|
||||
r0 = rf(ctx, ownerID, serialID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(certs.Cert)
|
||||
}
|
||||
|
||||
var crts []certs.Cert
|
||||
for _, tc := range oc {
|
||||
for i, v := range tc {
|
||||
if uint64(i) >= offset && uint64(i) < offset+limit {
|
||||
crts = append(crts, v)
|
||||
}
|
||||
}
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = rf(ctx, ownerID, serialID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
page := certs.Page{
|
||||
Certs: crts,
|
||||
Total: c.counter,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
}
|
||||
return page, nil
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (c *certsRepoMock) Remove(ctx context.Context, ownerID, serial string) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
crt, ok := c.certsBySerial[serial]
|
||||
if !ok {
|
||||
return repoerr.ErrNotFound
|
||||
// RetrieveByThing provides a mock function with given fields: ctx, ownerID, thingID, offset, limit
|
||||
func (_m *Repository) RetrieveByThing(ctx context.Context, ownerID string, thingID string, offset uint64, limit uint64) (certs.Page, error) {
|
||||
ret := _m.Called(ctx, ownerID, thingID, offset, limit)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveByThing")
|
||||
}
|
||||
delete(c.certsBySerial, crt.Serial)
|
||||
delete(c.certsByThingID, crt.ThingID)
|
||||
return nil
|
||||
|
||||
var r0 certs.Page
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) (certs.Page, error)); ok {
|
||||
return rf(ctx, ownerID, thingID, offset, limit)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) certs.Page); ok {
|
||||
r0 = rf(ctx, ownerID, thingID, offset, limit)
|
||||
} else {
|
||||
r0 = ret.Get(0).(certs.Page)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, uint64, uint64) error); ok {
|
||||
r1 = rf(ctx, ownerID, thingID, offset, limit)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (c *certsRepoMock) RetrieveByThing(ctx context.Context, ownerID, thingID string, offset, limit uint64) (certs.Page, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if limit <= 0 {
|
||||
return certs.Page{}, nil
|
||||
// Save provides a mock function with given fields: ctx, cert
|
||||
func (_m *Repository) Save(ctx context.Context, cert certs.Cert) (string, error) {
|
||||
ret := _m.Called(ctx, cert)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Save")
|
||||
}
|
||||
|
||||
cs, ok := c.certsByThingID[ownerID][thingID]
|
||||
if !ok {
|
||||
return certs.Page{}, repoerr.ErrNotFound
|
||||
var r0 string
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, certs.Cert) (string, error)); ok {
|
||||
return rf(ctx, cert)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, certs.Cert) string); ok {
|
||||
r0 = rf(ctx, cert)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
var crts []certs.Cert
|
||||
for i, v := range cs {
|
||||
if uint64(i) >= offset && uint64(i) < offset+limit {
|
||||
crts = append(crts, v)
|
||||
}
|
||||
if rf, ok := ret.Get(1).(func(context.Context, certs.Cert) error); ok {
|
||||
r1 = rf(ctx, cert)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
page := certs.Page{
|
||||
Certs: crts,
|
||||
Total: c.counter,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
}
|
||||
return page, nil
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (c *certsRepoMock) RetrieveBySerial(ctx context.Context, ownerID, serialID string) (certs.Cert, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewRepository(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Repository {
|
||||
mock := &Repository{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
crt, ok := c.certsBySerial[serialID]
|
||||
if !ok {
|
||||
return certs.Cert{}, repoerr.ErrNotFound
|
||||
}
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return crt, nil
|
||||
return mock
|
||||
}
|
||||
|
||||
+106
-167
@@ -1,196 +1,135 @@
|
||||
// Code generated by mockery v2.42.1. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"sync"
|
||||
"time"
|
||||
context "context"
|
||||
|
||||
"github.com/absmach/magistrala/certs/pki"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
pki "github.com/absmach/magistrala/certs/pki"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
time "time"
|
||||
)
|
||||
|
||||
const keyBits = 2048
|
||||
|
||||
var (
|
||||
errPrivateKeyEmpty = errors.New("private key is empty")
|
||||
errPrivateKeyUnsupportedType = errors.New("private key type is unsupported")
|
||||
)
|
||||
|
||||
var _ pki.Agent = (*agent)(nil)
|
||||
|
||||
type agent struct {
|
||||
AuthTimeout time.Duration
|
||||
TLSCert tls.Certificate
|
||||
X509Cert *x509.Certificate
|
||||
TTL string
|
||||
mu sync.Mutex
|
||||
counter uint64
|
||||
certs map[string]pki.Cert
|
||||
// Agent is an autogenerated mock type for the Agent type
|
||||
type Agent struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func NewPkiAgent(tlsCert tls.Certificate, caCert *x509.Certificate, ttl string, timeout time.Duration) pki.Agent {
|
||||
return &agent{
|
||||
AuthTimeout: timeout,
|
||||
TLSCert: tlsCert,
|
||||
X509Cert: caCert,
|
||||
TTL: ttl,
|
||||
certs: make(map[string]pki.Cert),
|
||||
// IssueCert provides a mock function with given fields: cn, ttl
|
||||
func (_m *Agent) IssueCert(cn string, ttl string) (pki.Cert, error) {
|
||||
ret := _m.Called(cn, ttl)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for IssueCert")
|
||||
}
|
||||
|
||||
var r0 pki.Cert
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string, string) (pki.Cert, error)); ok {
|
||||
return rf(cn, ttl)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string, string) pki.Cert); ok {
|
||||
r0 = rf(cn, ttl)
|
||||
} else {
|
||||
r0 = ret.Get(0).(pki.Cert)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string, string) error); ok {
|
||||
r1 = rf(cn, ttl)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (a *agent) IssueCert(cn, ttl string) (pki.Cert, error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
// LoginAndRenew provides a mock function with given fields: ctx
|
||||
func (_m *Agent) LoginAndRenew(ctx context.Context) error {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
if a.X509Cert == nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, pki.ErrMissingCACertificate)
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for LoginAndRenew")
|
||||
}
|
||||
|
||||
var priv interface{}
|
||||
priv, err := rsa.GenerateKey(rand.Reader, keyBits)
|
||||
if err != nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context) error); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
if ttl == "" {
|
||||
ttl = a.TTL
|
||||
}
|
||||
|
||||
notBefore := time.Now()
|
||||
validFor, err := time.ParseDuration(ttl)
|
||||
if err != nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
|
||||
}
|
||||
notAfter := notBefore.Add(validFor)
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
|
||||
}
|
||||
|
||||
tmpl := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Magistrala"},
|
||||
CommonName: cn,
|
||||
OrganizationalUnit: []string{"magistrala"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
SubjectKeyId: []byte{1, 2, 3, 4, 6},
|
||||
}
|
||||
|
||||
pubKey, err := publicKey(priv)
|
||||
if err != nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, a.X509Cert, pubKey, a.TLSCert.PrivateKey)
|
||||
if err != nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
|
||||
}
|
||||
|
||||
x509cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
|
||||
}
|
||||
|
||||
var bw, keyOut bytes.Buffer
|
||||
buffWriter := bufio.NewWriter(&bw)
|
||||
buffKeyOut := bufio.NewWriter(&keyOut)
|
||||
|
||||
if err := pem.Encode(buffWriter, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
|
||||
}
|
||||
buffWriter.Flush()
|
||||
cert := bw.String()
|
||||
|
||||
block, err := pemBlockForKey(priv)
|
||||
if err != nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
|
||||
}
|
||||
if err := pem.Encode(buffKeyOut, block); err != nil {
|
||||
return pki.Cert{}, errors.Wrap(pki.ErrFailedCertCreation, err)
|
||||
}
|
||||
buffKeyOut.Flush()
|
||||
key := keyOut.String()
|
||||
|
||||
a.certs[x509cert.SerialNumber.String()] = pki.Cert{
|
||||
ClientCert: cert,
|
||||
}
|
||||
a.counter++
|
||||
|
||||
return pki.Cert{
|
||||
ClientCert: cert,
|
||||
ClientKey: key,
|
||||
Serial: x509cert.SerialNumber.String(),
|
||||
Expire: x509cert.NotAfter.Unix(),
|
||||
IssuingCA: x509cert.Issuer.String(),
|
||||
}, nil
|
||||
return r0
|
||||
}
|
||||
|
||||
func (a *agent) Read(serial string) (pki.Cert, error) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
// Read provides a mock function with given fields: serial
|
||||
func (_m *Agent) Read(serial string) (pki.Cert, error) {
|
||||
ret := _m.Called(serial)
|
||||
|
||||
crt, ok := a.certs[serial]
|
||||
if !ok {
|
||||
return pki.Cert{}, repoerr.ErrNotFound
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Read")
|
||||
}
|
||||
|
||||
return crt, nil
|
||||
}
|
||||
|
||||
func (a *agent) Revoke(serial string) (time.Time, error) {
|
||||
return time.Now(), nil
|
||||
}
|
||||
|
||||
func (a *agent) LoginAndRenew(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func publicKey(priv interface{}) (interface{}, error) {
|
||||
if priv == nil {
|
||||
return nil, errPrivateKeyEmpty
|
||||
var r0 pki.Cert
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string) (pki.Cert, error)); ok {
|
||||
return rf(serial)
|
||||
}
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey, nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey, nil
|
||||
default:
|
||||
return nil, errPrivateKeyUnsupportedType
|
||||
if rf, ok := ret.Get(0).(func(string) pki.Cert); ok {
|
||||
r0 = rf(serial)
|
||||
} else {
|
||||
r0 = ret.Get(0).(pki.Cert)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(serial)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func pemBlockForKey(priv interface{}) (*pem.Block, error) {
|
||||
switch k := priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}, nil
|
||||
case *ecdsa.PrivateKey:
|
||||
b, err := x509.MarshalECPrivateKey(k)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}, nil
|
||||
default:
|
||||
return nil, nil
|
||||
// Revoke provides a mock function with given fields: serial
|
||||
func (_m *Agent) Revoke(serial string) (time.Time, error) {
|
||||
ret := _m.Called(serial)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Revoke")
|
||||
}
|
||||
|
||||
var r0 time.Time
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string) (time.Time, error)); ok {
|
||||
return rf(serial)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string) time.Time); ok {
|
||||
r0 = rf(serial)
|
||||
} else {
|
||||
r0 = ret.Get(0).(time.Time)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(serial)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewAgent creates a new instance of Agent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAgent(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Agent {
|
||||
mock := &Agent{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
// Code generated by mockery v2.42.1. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
certs "github.com/absmach/magistrala/certs"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// IssueCert provides a mock function with given fields: ctx, token, thingID, ttl
|
||||
func (_m *Service) IssueCert(ctx context.Context, token string, thingID string, ttl string) (certs.Cert, error) {
|
||||
ret := _m.Called(ctx, token, thingID, ttl)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for IssueCert")
|
||||
}
|
||||
|
||||
var r0 certs.Cert
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (certs.Cert, error)); ok {
|
||||
return rf(ctx, token, thingID, ttl)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, string) certs.Cert); ok {
|
||||
r0 = rf(ctx, token, thingID, ttl)
|
||||
} else {
|
||||
r0 = ret.Get(0).(certs.Cert)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
|
||||
r1 = rf(ctx, token, thingID, ttl)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListCerts provides a mock function with given fields: ctx, token, thingID, offset, limit
|
||||
func (_m *Service) ListCerts(ctx context.Context, token string, thingID string, offset uint64, limit uint64) (certs.Page, error) {
|
||||
ret := _m.Called(ctx, token, thingID, offset, limit)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListCerts")
|
||||
}
|
||||
|
||||
var r0 certs.Page
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) (certs.Page, error)); ok {
|
||||
return rf(ctx, token, thingID, offset, limit)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) certs.Page); ok {
|
||||
r0 = rf(ctx, token, thingID, offset, limit)
|
||||
} else {
|
||||
r0 = ret.Get(0).(certs.Page)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, uint64, uint64) error); ok {
|
||||
r1 = rf(ctx, token, thingID, offset, limit)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListSerials provides a mock function with given fields: ctx, token, thingID, offset, limit
|
||||
func (_m *Service) ListSerials(ctx context.Context, token string, thingID string, offset uint64, limit uint64) (certs.Page, error) {
|
||||
ret := _m.Called(ctx, token, thingID, offset, limit)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListSerials")
|
||||
}
|
||||
|
||||
var r0 certs.Page
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) (certs.Page, error)); ok {
|
||||
return rf(ctx, token, thingID, offset, limit)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, uint64, uint64) certs.Page); ok {
|
||||
r0 = rf(ctx, token, thingID, offset, limit)
|
||||
} else {
|
||||
r0 = ret.Get(0).(certs.Page)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, uint64, uint64) error); ok {
|
||||
r1 = rf(ctx, token, thingID, offset, limit)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RevokeCert provides a mock function with given fields: ctx, token, serialID
|
||||
func (_m *Service) RevokeCert(ctx context.Context, token string, serialID string) (certs.Revoke, error) {
|
||||
ret := _m.Called(ctx, token, serialID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RevokeCert")
|
||||
}
|
||||
|
||||
var r0 certs.Revoke
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) (certs.Revoke, error)); ok {
|
||||
return rf(ctx, token, serialID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) certs.Revoke); ok {
|
||||
r0 = rf(ctx, token, serialID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(certs.Revoke)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = rf(ctx, token, serialID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ViewCert provides a mock function with given fields: ctx, token, serialID
|
||||
func (_m *Service) ViewCert(ctx context.Context, token string, serialID string) (certs.Cert, error) {
|
||||
ret := _m.Called(ctx, token, serialID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ViewCert")
|
||||
}
|
||||
|
||||
var r0 certs.Cert
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) (certs.Cert, error)); ok {
|
||||
return rf(ctx, token, serialID)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) certs.Cert); ok {
|
||||
r0 = rf(ctx, token, serialID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(certs.Cert)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = rf(ctx, token, serialID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -54,6 +54,8 @@ type Cert struct {
|
||||
}
|
||||
|
||||
// Agent represents the Vault PKI interface.
|
||||
//
|
||||
//go:generate mockery --name Agent --output=../mocks --filename pki.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Agent interface {
|
||||
// IssueCert issues certificate on PKI
|
||||
IssueCert(cn, ttl string) (Cert, error)
|
||||
|
||||
@@ -31,6 +31,8 @@ var _ Service = (*certsService)(nil)
|
||||
|
||||
// Service specifies an API that must be fulfilled by the domain service
|
||||
// implementation, and all of its decorators (e.g. logging & metrics).
|
||||
//
|
||||
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Service interface {
|
||||
// IssueCert issues certificate for given thing id if access is granted with token
|
||||
IssueCert(ctx context.Context, token, thingID, ttl string) (Cert, error)
|
||||
|
||||
+255
-175
@@ -14,7 +14,9 @@ import (
|
||||
authmocks "github.com/absmach/magistrala/auth/mocks"
|
||||
"github.com/absmach/magistrala/certs"
|
||||
"github.com/absmach/magistrala/certs/mocks"
|
||||
"github.com/absmach/magistrala/certs/pki"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
mgsdk "github.com/absmach/magistrala/pkg/sdk/go"
|
||||
sdkmocks "github.com/absmach/magistrala/pkg/sdk/mocks"
|
||||
@@ -33,217 +35,271 @@ const (
|
||||
ttl = "1h"
|
||||
certNum = 10
|
||||
validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22"
|
||||
|
||||
cfgAuthTimeout = "1s"
|
||||
|
||||
caPath = "../docker/ssl/certs/ca.crt"
|
||||
caKeyPath = "../docker/ssl/certs/ca.key"
|
||||
cfgSignHoursValid = "24h"
|
||||
instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
|
||||
)
|
||||
|
||||
func newService(t *testing.T) (certs.Service, *authmocks.AuthClient, *sdkmocks.SDK) {
|
||||
func newService(_ *testing.T) (certs.Service, *mocks.Repository, *mocks.Agent, *authmocks.AuthClient, *sdkmocks.SDK) {
|
||||
repo := new(mocks.Repository)
|
||||
agent := new(mocks.Agent)
|
||||
auth := new(authmocks.AuthClient)
|
||||
|
||||
sdk := new(sdkmocks.SDK)
|
||||
repo := mocks.NewCertsRepository()
|
||||
|
||||
tlsCert, caCert, err := certs.LoadCertificates(caPath, caKeyPath)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected cert loading error: %s\n", err))
|
||||
return certs.New(auth, repo, sdk, agent), repo, agent, auth, sdk
|
||||
}
|
||||
|
||||
authTimeout, err := time.ParseDuration(cfgAuthTimeout)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected auth timeout parsing error: %s\n", err))
|
||||
|
||||
pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignHoursValid, authTimeout)
|
||||
|
||||
return certs.New(auth, repo, sdk, pki), auth, sdk
|
||||
var cert = certs.Cert{
|
||||
OwnerID: validID,
|
||||
ThingID: thingID,
|
||||
Serial: "",
|
||||
Expire: time.Time{},
|
||||
}
|
||||
|
||||
func TestIssueCert(t *testing.T) {
|
||||
svc, auth, sdk := newService(t)
|
||||
|
||||
svc, repo, agent, auth, sdk := newService(t)
|
||||
cases := []struct {
|
||||
token string
|
||||
desc string
|
||||
thingID string
|
||||
ttl string
|
||||
key string
|
||||
err error
|
||||
token string
|
||||
desc string
|
||||
thingID string
|
||||
ttl string
|
||||
key string
|
||||
pki pki.Cert
|
||||
identifyRes *magistrala.IdentityRes
|
||||
identifyErr error
|
||||
thingErr errors.SDKError
|
||||
issueCertErr error
|
||||
repoErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "issue new cert",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
ttl: ttl,
|
||||
err: nil,
|
||||
pki: pki.Cert{
|
||||
ClientCert: "",
|
||||
IssuingCA: "",
|
||||
CAChain: []string{},
|
||||
ClientKey: "",
|
||||
PrivateKeyType: "",
|
||||
Serial: "",
|
||||
Expire: 0,
|
||||
},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
},
|
||||
{
|
||||
desc: "issue new cert for non existing thing id",
|
||||
token: token,
|
||||
thingID: "2",
|
||||
ttl: ttl,
|
||||
err: certs.ErrFailedCertCreation,
|
||||
pki: pki.Cert{
|
||||
ClientCert: "",
|
||||
IssuingCA: "",
|
||||
CAChain: []string{},
|
||||
ClientKey: "",
|
||||
PrivateKeyType: "",
|
||||
Serial: "",
|
||||
Expire: 0,
|
||||
},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
thingErr: errors.NewSDKError(errors.ErrMalformedEntity),
|
||||
err: certs.ErrFailedCertCreation,
|
||||
},
|
||||
{
|
||||
desc: "issue new cert for non existing thing id",
|
||||
desc: "issue new cert for invalid token",
|
||||
token: invalid,
|
||||
thingID: thingID,
|
||||
ttl: ttl,
|
||||
err: svcerr.ErrAuthentication,
|
||||
pki: pki.Cert{
|
||||
ClientCert: "",
|
||||
IssuingCA: "",
|
||||
CAChain: []string{},
|
||||
ClientKey: "",
|
||||
PrivateKeyType: "",
|
||||
Serial: "",
|
||||
Expire: 0,
|
||||
},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
identifyErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err)
|
||||
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: tc.thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, errors.NewSDKError(tc.err))
|
||||
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
|
||||
sdkCall := sdk.On("Thing", tc.thingID, tc.token).Return(mgsdk.Thing{ID: tc.thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, tc.thingErr)
|
||||
agentCall := agent.On("IssueCert", thingKey, tc.ttl).Return(tc.pki, tc.issueCertErr)
|
||||
repoCall := repo.On("Save", context.Background(), mock.Anything).Return("", tc.repoErr)
|
||||
|
||||
c, err := svc.IssueCert(context.Background(), tc.token, tc.thingID, tc.ttl)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
cert, _ := certs.ReadCert([]byte(c.ClientCert))
|
||||
if cert != nil {
|
||||
assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, thingKey, cert.Subject.CommonName))
|
||||
}
|
||||
authCall.Unset()
|
||||
sdkCall.Unset()
|
||||
agentCall.Unset()
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeCert(t *testing.T) {
|
||||
svc, auth, sdk := newService(t)
|
||||
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
|
||||
_, err := svc.IssueCert(context.Background(), token, thingID, ttl)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
|
||||
svc, repo, _, auth, sdk := newService(t)
|
||||
cases := []struct {
|
||||
token string
|
||||
desc string
|
||||
thingID string
|
||||
err error
|
||||
token string
|
||||
desc string
|
||||
thingID string
|
||||
page certs.Page
|
||||
identifyRes *magistrala.IdentityRes
|
||||
identifyErr error
|
||||
authErr error
|
||||
thingErr errors.SDKError
|
||||
repoErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "revoke cert",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
err: nil,
|
||||
desc: "revoke cert",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
page: certs.Page{Limit: 10000, Offset: 0, Total: 1, Certs: []certs.Cert{cert}},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
},
|
||||
{
|
||||
desc: "revoke cert for invalid token",
|
||||
token: invalid,
|
||||
thingID: thingID,
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "revoke cert for invalid token",
|
||||
token: invalid,
|
||||
thingID: thingID,
|
||||
page: certs.Page{},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
identifyErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "revoke cert for invalid thing id",
|
||||
token: token,
|
||||
thingID: "2",
|
||||
err: certs.ErrFailedCertRevocation,
|
||||
desc: "revoke cert for invalid thing id",
|
||||
token: token,
|
||||
thingID: "2",
|
||||
page: certs.Page{},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
thingErr: errors.NewSDKError(certs.ErrFailedCertCreation),
|
||||
err: certs.ErrFailedCertRevocation,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err)
|
||||
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: tc.thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, errors.NewSDKError(tc.err))
|
||||
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
|
||||
authCall1 := auth.On("Authorize", context.Background(), mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.authErr)
|
||||
sdkCall := sdk.On("Thing", tc.thingID, tc.token).Return(mgsdk.Thing{ID: tc.thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, tc.thingErr)
|
||||
repoCall := repo.On("RetrieveByThing", context.Background(), validID, tc.thingID, tc.page.Offset, tc.page.Limit).Return(certs.Page{}, tc.repoErr)
|
||||
|
||||
_, err := svc.RevokeCert(context.Background(), tc.token, tc.thingID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
authCall.Unset()
|
||||
authCall1.Unset()
|
||||
sdkCall.Unset()
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestListCerts(t *testing.T) {
|
||||
svc, auth, sdk := newService(t)
|
||||
svc, repo, agent, auth, _ := newService(t)
|
||||
var mycerts []certs.Cert
|
||||
for i := 0; i < certNum; i++ {
|
||||
c := certs.Cert{
|
||||
OwnerID: validID,
|
||||
ThingID: thingID,
|
||||
Serial: fmt.Sprintf("%d", i),
|
||||
Expire: time.Now().Add(time.Hour),
|
||||
}
|
||||
mycerts = append(mycerts, c)
|
||||
}
|
||||
|
||||
for i := 0; i < certNum; i++ {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
|
||||
_, err := svc.IssueCert(context.Background(), token, thingID, ttl)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
agent.On("Read", fmt.Sprintf("%d", i)).Return(pki.Cert{}, nil)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
token string
|
||||
desc string
|
||||
thingID string
|
||||
offset uint64
|
||||
limit uint64
|
||||
size uint64
|
||||
err error
|
||||
token string
|
||||
desc string
|
||||
thingID string
|
||||
page certs.Page
|
||||
cert certs.Cert
|
||||
identifyRes *magistrala.IdentityRes
|
||||
identifyErr error
|
||||
repoErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list all certs with valid token",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
offset: 0,
|
||||
limit: certNum,
|
||||
size: certNum,
|
||||
err: nil,
|
||||
page: certs.Page{Limit: certNum, Offset: 0, Total: certNum, Certs: mycerts},
|
||||
cert: certs.Cert{
|
||||
OwnerID: validID,
|
||||
ThingID: thingID,
|
||||
Serial: "0",
|
||||
Expire: time.Now().Add(time.Hour),
|
||||
},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
},
|
||||
{
|
||||
desc: "list all certs with invalid token",
|
||||
token: invalid,
|
||||
thingID: thingID,
|
||||
offset: 0,
|
||||
limit: certNum,
|
||||
size: 0,
|
||||
err: svcerr.ErrAuthentication,
|
||||
page: certs.Page{},
|
||||
cert: certs.Cert{
|
||||
OwnerID: validID,
|
||||
ThingID: thingID,
|
||||
Serial: fmt.Sprintf("%d", certNum-1),
|
||||
Expire: time.Now().Add(time.Hour),
|
||||
},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
identifyErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "list half certs with valid token",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
offset: certNum / 2,
|
||||
limit: certNum,
|
||||
size: certNum / 2,
|
||||
err: nil,
|
||||
page: certs.Page{Limit: certNum, Offset: certNum / 2, Total: certNum / 2, Certs: mycerts[certNum/2:]},
|
||||
cert: certs.Cert{
|
||||
OwnerID: validID,
|
||||
ThingID: thingID,
|
||||
Serial: fmt.Sprintf("%d", certNum/2),
|
||||
Expire: time.Now().Add(time.Hour),
|
||||
},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
},
|
||||
{
|
||||
desc: "list last cert with valid token",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
offset: certNum - 1,
|
||||
limit: certNum,
|
||||
size: 1,
|
||||
err: nil,
|
||||
page: certs.Page{Limit: certNum, Offset: certNum - 1, Total: 1, Certs: []certs.Cert{mycerts[certNum-1]}},
|
||||
cert: certs.Cert{
|
||||
OwnerID: validID,
|
||||
ThingID: thingID,
|
||||
Serial: fmt.Sprintf("%d", certNum-1),
|
||||
Expire: time.Now().Add(time.Hour),
|
||||
},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
page, err := svc.ListCerts(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit)
|
||||
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
|
||||
repoCall := repo.On("RetrieveByThing", context.Background(), validID, thingID, tc.page.Offset, tc.page.Limit).Return(tc.page, tc.repoErr)
|
||||
|
||||
page, err := svc.ListCerts(context.Background(), tc.token, tc.thingID, tc.page.Offset, tc.page.Limit)
|
||||
size := uint64(len(page.Certs))
|
||||
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size))
|
||||
assert.Equal(t, tc.page.Total, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.page.Total, size))
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
authCall.Unset()
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSerials(t *testing.T) {
|
||||
svc, auth, sdk := newService(t)
|
||||
svc, repo, _, auth, _ := newService(t)
|
||||
|
||||
var issuedCerts []certs.Cert
|
||||
for i := 0; i < certNum; i++ {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
|
||||
cert, err := svc.IssueCert(context.Background(), token, thingID, ttl)
|
||||
assert.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
|
||||
crt := certs.Cert{
|
||||
OwnerID: cert.OwnerID,
|
||||
ThingID: cert.ThingID,
|
||||
@@ -254,72 +310,83 @@ func TestListSerials(t *testing.T) {
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
token string
|
||||
desc string
|
||||
thingID string
|
||||
offset uint64
|
||||
limit uint64
|
||||
certs []certs.Cert
|
||||
err error
|
||||
token string
|
||||
desc string
|
||||
thingID string
|
||||
offset uint64
|
||||
limit uint64
|
||||
certs []certs.Cert
|
||||
identifyRes *magistrala.IdentityRes
|
||||
identifyErr error
|
||||
repoErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list all certs with valid token",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
offset: 0,
|
||||
limit: certNum,
|
||||
certs: issuedCerts,
|
||||
err: nil,
|
||||
desc: "list all certs with valid token",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
offset: 0,
|
||||
limit: certNum,
|
||||
certs: issuedCerts,
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
},
|
||||
{
|
||||
desc: "list all certs with invalid token",
|
||||
token: invalid,
|
||||
thingID: thingID,
|
||||
offset: 0,
|
||||
limit: certNum,
|
||||
certs: nil,
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "list all certs with invalid token",
|
||||
token: invalid,
|
||||
thingID: thingID,
|
||||
offset: 0,
|
||||
limit: certNum,
|
||||
certs: nil,
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
identifyErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "list half certs with valid token",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
offset: certNum / 2,
|
||||
limit: certNum,
|
||||
certs: issuedCerts[certNum/2:],
|
||||
err: nil,
|
||||
desc: "list half certs with valid token",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
offset: certNum / 2,
|
||||
limit: certNum,
|
||||
certs: issuedCerts[certNum/2:],
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
},
|
||||
{
|
||||
desc: "list last cert with valid token",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
offset: certNum - 1,
|
||||
limit: certNum,
|
||||
certs: []certs.Cert{issuedCerts[certNum-1]},
|
||||
err: nil,
|
||||
desc: "list last cert with valid token",
|
||||
token: token,
|
||||
thingID: thingID,
|
||||
offset: certNum - 1,
|
||||
limit: certNum,
|
||||
certs: []certs.Cert{issuedCerts[certNum-1]},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
|
||||
repoCall := repo.On("RetrieveByThing", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(certs.Page{Limit: tc.limit, Offset: tc.offset, Total: certNum, Certs: tc.certs}, tc.repoErr)
|
||||
|
||||
page, err := svc.ListSerials(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit)
|
||||
assert.Equal(t, tc.certs, page.Certs, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.certs, page.Certs))
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
authCall.Unset()
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewCert(t *testing.T) {
|
||||
svc, auth, sdk := newService(t)
|
||||
svc, repo, agent, auth, sdk := newService(t)
|
||||
|
||||
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
sdkCall := sdk.On("Thing", thingID, token).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
|
||||
agentCall := agent.On("IssueCert", thingKey, ttl).Return(pki.Cert{}, nil)
|
||||
repoCall := repo.On("Save", context.Background(), mock.Anything).Return("", nil)
|
||||
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall2 := sdk.On("Thing", mock.Anything, mock.Anything).Return(mgsdk.Thing{ID: thingID, Credentials: mgsdk.Credentials{Secret: thingKey}}, nil)
|
||||
ic, err := svc.IssueCert(context.Background(), token, thingID, ttl)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
|
||||
authCall.Unset()
|
||||
sdkCall.Unset()
|
||||
agentCall.Unset()
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
|
||||
cert := certs.Cert{
|
||||
ThingID: thingID,
|
||||
@@ -329,40 +396,53 @@ func TestViewCert(t *testing.T) {
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
token string
|
||||
desc string
|
||||
serialID string
|
||||
cert certs.Cert
|
||||
err error
|
||||
token string
|
||||
desc string
|
||||
serialID string
|
||||
cert certs.Cert
|
||||
identifyRes *magistrala.IdentityRes
|
||||
identifyErr error
|
||||
repoErr error
|
||||
agentErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list cert with valid token and serial",
|
||||
token: token,
|
||||
serialID: cert.Serial,
|
||||
cert: cert,
|
||||
err: nil,
|
||||
desc: "list cert with valid token and serial",
|
||||
token: token,
|
||||
serialID: cert.Serial,
|
||||
cert: cert,
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
},
|
||||
{
|
||||
desc: "list cert with invalid token",
|
||||
token: invalid,
|
||||
serialID: cert.Serial,
|
||||
cert: certs.Cert{},
|
||||
err: svcerr.ErrAuthentication,
|
||||
desc: "list cert with invalid token",
|
||||
token: invalid,
|
||||
serialID: cert.Serial,
|
||||
cert: certs.Cert{},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
identifyErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "list cert with invalid serial",
|
||||
token: token,
|
||||
serialID: invalid,
|
||||
cert: certs.Cert{},
|
||||
err: svcerr.ErrNotFound,
|
||||
desc: "list cert with invalid serial",
|
||||
token: token,
|
||||
serialID: invalid,
|
||||
cert: certs.Cert{},
|
||||
identifyRes: &magistrala.IdentityRes{Id: validID},
|
||||
repoErr: repoerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
authCall := auth.On("Identify", context.Background(), &magistrala.IdentityReq{Token: tc.token}).Return(tc.identifyRes, tc.identifyErr)
|
||||
repoCall := repo.On("RetrieveBySerial", context.Background(), validID, tc.serialID).Return(tc.cert, tc.repoErr)
|
||||
agentCall := agent.On("Read", tc.serialID).Return(pki.Cert{}, tc.agentErr)
|
||||
|
||||
cert, err := svc.ViewCert(context.Background(), tc.token, tc.serialID)
|
||||
assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.cert, cert))
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
authCall.Unset()
|
||||
repoCall.Unset()
|
||||
agentCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
+159
-182
@@ -10,65 +10,46 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
authmocks "github.com/absmach/magistrala/auth/mocks"
|
||||
"github.com/absmach/magistrala/certs"
|
||||
httpapi "github.com/absmach/magistrala/certs/api"
|
||||
"github.com/absmach/magistrala/certs/mocks"
|
||||
"github.com/absmach/magistrala/internal/apiutil"
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/clients"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
sdk "github.com/absmach/magistrala/pkg/sdk/go"
|
||||
thmocks "github.com/absmach/magistrala/things/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
|
||||
|
||||
var (
|
||||
thingID = "1"
|
||||
caPath = "../../../docker/ssl/certs/ca.crt"
|
||||
caKeyPath = "../../../docker/ssl/certs/ca.key"
|
||||
cfgAuthTimeout = "1s"
|
||||
cfgSignHoursValid = "24h"
|
||||
)
|
||||
var thingID = "1"
|
||||
|
||||
func setupCerts() (*httptest.Server, *authmocks.AuthClient, *thmocks.Repository, error) {
|
||||
server, trepo, _, auth, _ := setupThings()
|
||||
config := sdk.Config{
|
||||
ThingsURL: server.URL,
|
||||
}
|
||||
var c = certs.Cert{
|
||||
OwnerID: "",
|
||||
ThingID: thingID,
|
||||
ClientCert: "",
|
||||
IssuingCA: "",
|
||||
CAChain: []string{},
|
||||
ClientKey: "",
|
||||
PrivateKeyType: "",
|
||||
Serial: "",
|
||||
Expire: time.Time{},
|
||||
}
|
||||
|
||||
mgsdk := sdk.NewSDK(config)
|
||||
repo := mocks.NewCertsRepository()
|
||||
|
||||
tlsCert, caCert, err := certs.LoadCertificates(caPath, caKeyPath)
|
||||
if err != nil {
|
||||
return nil, auth, trepo, err
|
||||
}
|
||||
|
||||
authTimeout, err := time.ParseDuration(cfgAuthTimeout)
|
||||
if err != nil {
|
||||
return nil, auth, trepo, err
|
||||
}
|
||||
|
||||
pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignHoursValid, authTimeout)
|
||||
|
||||
svc := certs.New(auth, repo, mgsdk, pki)
|
||||
func setupCerts() (*httptest.Server, *mocks.Service) {
|
||||
svc := new(mocks.Service)
|
||||
logger := mglog.NewMock()
|
||||
mux := httpapi.MakeHandler(svc, logger, instanceID)
|
||||
|
||||
return httptest.NewServer(mux), auth, trepo, nil
|
||||
return httptest.NewServer(mux), svc
|
||||
}
|
||||
|
||||
func TestIssueCert(t *testing.T) {
|
||||
ts, auth, trepo, err := setupCerts()
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
|
||||
ts, svc := setupCerts()
|
||||
defer ts.Close()
|
||||
|
||||
sdkConf := sdk.Config{
|
||||
@@ -84,84 +65,93 @@ func TestIssueCert(t *testing.T) {
|
||||
thingID string
|
||||
duration string
|
||||
token string
|
||||
cRes certs.Cert
|
||||
err errors.SDKError
|
||||
svcerr error
|
||||
}{
|
||||
{
|
||||
desc: "create new cert with thing id and duration",
|
||||
thingID: thingID,
|
||||
duration: "10h",
|
||||
token: validToken,
|
||||
err: nil,
|
||||
cRes: c,
|
||||
},
|
||||
{
|
||||
desc: "create new cert with empty thing id and duration",
|
||||
thingID: "",
|
||||
duration: "10h",
|
||||
token: validToken,
|
||||
cRes: c,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
svcerr: errors.Wrap(certs.ErrFailedCertCreation, apiutil.ErrMissingID),
|
||||
},
|
||||
{
|
||||
desc: "create new cert with invalid thing id and duration",
|
||||
thingID: "ah",
|
||||
duration: "10h",
|
||||
token: validToken,
|
||||
cRes: c,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, certs.ErrFailedCertCreation), http.StatusBadRequest),
|
||||
svcerr: errors.Wrap(certs.ErrFailedCertCreation, apiutil.ErrValidation),
|
||||
},
|
||||
{
|
||||
desc: "create new cert with thing id and empty duration",
|
||||
thingID: thingID,
|
||||
duration: "",
|
||||
token: exampleUser1,
|
||||
cRes: c,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingCertData), http.StatusBadRequest),
|
||||
svcerr: errors.Wrap(certs.ErrFailedCertCreation, apiutil.ErrMissingCertData),
|
||||
},
|
||||
{
|
||||
desc: "create new cert with thing id and malformed duration",
|
||||
thingID: thingID,
|
||||
duration: "10g",
|
||||
token: exampleUser1,
|
||||
cRes: c,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidCertData), http.StatusBadRequest),
|
||||
svcerr: errors.Wrap(certs.ErrFailedCertCreation, apiutil.ErrInvalidCertData),
|
||||
},
|
||||
{
|
||||
desc: "create new cert with empty token",
|
||||
thingID: thingID,
|
||||
duration: "10h",
|
||||
token: "",
|
||||
cRes: c,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
|
||||
svcerr: errors.Wrap(certs.ErrFailedCertCreation, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "create new cert with invalid token",
|
||||
thingID: thingID,
|
||||
duration: "10h",
|
||||
token: authmocks.InvalidValue,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, svcerr.ErrAuthentication), http.StatusUnauthorized),
|
||||
cRes: c,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, certs.ErrFailedCertCreation), http.StatusUnauthorized),
|
||||
svcerr: errors.Wrap(certs.ErrFailedCertCreation, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "create new empty cert",
|
||||
thingID: "",
|
||||
duration: "",
|
||||
token: validToken,
|
||||
cRes: c,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
|
||||
svcerr: errors.Wrap(certs.ErrFailedCertCreation, certs.ErrFailedCertCreation),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID}, tc.err)
|
||||
cert, err := mgsdk.IssueCert(tc.thingID, tc.duration, tc.token)
|
||||
svcCall := svc.On("IssueCert", mock.Anything, tc.token, tc.thingID, tc.duration).Return(tc.cRes, tc.svcerr)
|
||||
|
||||
_, err := mgsdk.IssueCert(tc.thingID, tc.duration, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewCert(t *testing.T) {
|
||||
ts, auth, trepo, err := setupCerts()
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
|
||||
ts, svc := setupCerts()
|
||||
defer ts.Close()
|
||||
|
||||
sdkConf := sdk.Config{
|
||||
@@ -172,59 +162,52 @@ func TestViewCert(t *testing.T) {
|
||||
|
||||
mgsdk := sdk.NewSDK(sdkConf)
|
||||
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID}, nil)
|
||||
cert, err := mgsdk.IssueCert(thingID, "10h", token)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error during creating cert: %s", err))
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
certID string
|
||||
token string
|
||||
err errors.SDKError
|
||||
response sdk.Subscription
|
||||
desc string
|
||||
certID string
|
||||
token string
|
||||
err errors.SDKError
|
||||
svcerr error
|
||||
cRes certs.Cert
|
||||
}{
|
||||
{
|
||||
desc: "get existing cert",
|
||||
certID: cert.CertSerial,
|
||||
token: token,
|
||||
err: nil,
|
||||
response: sub1,
|
||||
desc: "get existing cert",
|
||||
certID: validID,
|
||||
token: token,
|
||||
cRes: c,
|
||||
},
|
||||
{
|
||||
desc: "get non-existent cert",
|
||||
certID: "43",
|
||||
token: token,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, svcerr.ErrNotFound), http.StatusNotFound),
|
||||
response: sdk.Subscription{},
|
||||
desc: "get non-existent cert",
|
||||
certID: "43",
|
||||
token: token,
|
||||
cRes: c,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, svcerr.ErrNotFound), http.StatusNotFound),
|
||||
svcerr: errors.Wrap(svcerr.ErrNotFound, repoerr.ErrNotFound),
|
||||
},
|
||||
{
|
||||
desc: "get cert with invalid token",
|
||||
certID: cert.CertSerial,
|
||||
token: "",
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
|
||||
response: sdk.Subscription{},
|
||||
desc: "get cert with invalid token",
|
||||
certID: validID,
|
||||
token: "",
|
||||
cRes: c,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
|
||||
svcerr: errors.Wrap(svcerr.ErrAuthentication, apiutil.ErrBearerToken),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
svcCall := svc.On("ViewCert", mock.Anything, tc.token, tc.certID).Return(tc.cRes, tc.svcerr)
|
||||
|
||||
cert, err := mgsdk.ViewCert(tc.certID, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewCertByThing(t *testing.T) {
|
||||
ts, auth, trepo, err := setupCerts()
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
|
||||
ts, svc := setupCerts()
|
||||
defer ts.Close()
|
||||
|
||||
sdkConf := sdk.Config{
|
||||
@@ -235,127 +218,121 @@ func TestViewCertByThing(t *testing.T) {
|
||||
|
||||
mgsdk := sdk.NewSDK(sdkConf)
|
||||
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID}, nil)
|
||||
_, err = mgsdk.IssueCert(thingID, "10h", token)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error during creating cert: %s", err))
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
thingID string
|
||||
token string
|
||||
err errors.SDKError
|
||||
response sdk.Subscription
|
||||
}{
|
||||
{
|
||||
desc: "get existing cert",
|
||||
thingID: thingID,
|
||||
token: token,
|
||||
err: nil,
|
||||
response: sub1,
|
||||
},
|
||||
{
|
||||
desc: "get non-existent cert",
|
||||
thingID: "43",
|
||||
token: token,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, repoerr.ErrNotFound), http.StatusNotFound),
|
||||
response: sdk.Subscription{},
|
||||
},
|
||||
{
|
||||
desc: "get cert with invalid token",
|
||||
thingID: thingID,
|
||||
token: "",
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
|
||||
response: sdk.Subscription{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
cert, err := mgsdk.ViewCertByThing(tc.thingID, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeCert(t *testing.T) {
|
||||
ts, auth, trepo, err := setupCerts()
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
|
||||
defer ts.Close()
|
||||
|
||||
sdkConf := sdk.Config{
|
||||
CertsURL: ts.URL,
|
||||
MsgContentType: contentType,
|
||||
TLSVerification: false,
|
||||
}
|
||||
|
||||
mgsdk := sdk.NewSDK(sdkConf)
|
||||
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID}, nil)
|
||||
_, err = mgsdk.IssueCert(thingID, "10h", validToken)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error during creating cert: %s", err))
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
thingID string
|
||||
token string
|
||||
page certs.Page
|
||||
err errors.SDKError
|
||||
viewerr errors.SDKError
|
||||
svcerr error
|
||||
}{
|
||||
{
|
||||
desc: "revoke cert with invalid token",
|
||||
desc: "get existing cert",
|
||||
thingID: thingID,
|
||||
token: authmocks.InvalidValue,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication), http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "revoke non-existing cert",
|
||||
thingID: "2",
|
||||
token: token,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(certs.ErrFailedCertRevocation, svcerr.ErrNotFound), http.StatusNotFound),
|
||||
page: certs.Page{Certs: []certs.Cert{c}},
|
||||
},
|
||||
{
|
||||
desc: "revoke cert with empty token",
|
||||
desc: "get non-existent cert",
|
||||
thingID: "43",
|
||||
token: token,
|
||||
page: certs.Page{Certs: []certs.Cert{}},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, repoerr.ErrNotFound), http.StatusNotFound),
|
||||
svcerr: errors.Wrap(svcerr.ErrNotFound, repoerr.ErrNotFound),
|
||||
viewerr: errors.NewSDKError(svcerr.ErrViewEntity),
|
||||
},
|
||||
{
|
||||
desc: "get cert with invalid token",
|
||||
thingID: thingID,
|
||||
token: "",
|
||||
page: certs.Page{Certs: []certs.Cert{}},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "revoke existing cert",
|
||||
thingID: thingID,
|
||||
token: token,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "revoke deleted cert",
|
||||
thingID: thingID,
|
||||
token: token,
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(certs.ErrFailedToRemoveCertFromDB, svcerr.ErrNotFound), http.StatusNotFound),
|
||||
svcerr: errors.Wrap(svcerr.ErrAuthentication, apiutil.ErrBearerToken),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
|
||||
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
|
||||
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID}, nil)
|
||||
svcCall := svc.On("ListSerials", mock.Anything, tc.token, tc.thingID, tc.page.Offset, mock.Anything).Return(tc.page, tc.svcerr)
|
||||
svcCall1 := svc.On("ViewCertByThing", mock.Anything, tc.thingID, tc.token).Return(tc.page, tc.viewerr)
|
||||
|
||||
cert, err := mgsdk.ViewCertByThing(tc.thingID, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc))
|
||||
}
|
||||
svcCall.Unset()
|
||||
svcCall1.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeCert(t *testing.T) {
|
||||
ts, svc := setupCerts()
|
||||
defer ts.Close()
|
||||
|
||||
sdkConf := sdk.Config{
|
||||
CertsURL: ts.URL,
|
||||
MsgContentType: contentType,
|
||||
TLSVerification: false,
|
||||
}
|
||||
|
||||
mgsdk := sdk.NewSDK(sdkConf)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
page certs.Page
|
||||
thingID string
|
||||
token string
|
||||
svcResponse certs.Revoke
|
||||
err errors.SDKError
|
||||
svcerr error
|
||||
}{
|
||||
{
|
||||
desc: "revoke cert with invalid token",
|
||||
thingID: thingID,
|
||||
token: authmocks.InvalidValue,
|
||||
svcResponse: certs.Revoke{RevocationTime: time.Now()},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication), http.StatusUnauthorized),
|
||||
svcerr: errors.Wrap(svcerr.ErrAuthentication, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "revoke non-existing cert",
|
||||
thingID: "2",
|
||||
token: token,
|
||||
svcResponse: certs.Revoke{RevocationTime: time.Now()},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(certs.ErrFailedCertRevocation, svcerr.ErrNotFound), http.StatusNotFound),
|
||||
svcerr: errors.Wrap(certs.ErrFailedCertRevocation, svcerr.ErrNotFound),
|
||||
},
|
||||
{
|
||||
desc: "revoke cert with empty token",
|
||||
thingID: thingID,
|
||||
token: "",
|
||||
svcResponse: certs.Revoke{RevocationTime: time.Now()},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized),
|
||||
svcerr: errors.Wrap(svcerr.ErrAuthentication, apiutil.ErrBearerToken),
|
||||
},
|
||||
{
|
||||
desc: "revoke existing cert",
|
||||
thingID: thingID,
|
||||
token: token,
|
||||
svcResponse: certs.Revoke{RevocationTime: time.Now()},
|
||||
},
|
||||
{
|
||||
desc: "revoke deleted cert",
|
||||
thingID: thingID,
|
||||
token: token,
|
||||
svcResponse: certs.Revoke{RevocationTime: time.Now()},
|
||||
err: errors.NewSDKErrorWithStatus(errors.Wrap(certs.ErrFailedToRemoveCertFromDB, svcerr.ErrNotFound), http.StatusNotFound),
|
||||
svcerr: errors.Wrap(certs.ErrFailedToRemoveCertFromDB, svcerr.ErrNotFound),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
svcCall := svc.On("RevokeCert", mock.Anything, tc.token, tc.thingID).Return(tc.svcResponse, tc.svcerr)
|
||||
|
||||
response, err := mgsdk.RevokeCert(tc.thingID, tc.token)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, response, fmt.Sprintf("%s: got empty revocation time", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
sdk "github.com/absmach/magistrala/pkg/sdk/go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHealth(t *testing.T) {
|
||||
@@ -23,8 +22,7 @@ func TestHealth(t *testing.T) {
|
||||
auth.Test(t)
|
||||
defer usclsv.Close()
|
||||
|
||||
CertTs, _, _, err := setupCerts()
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err))
|
||||
CertTs, _ := setupCerts()
|
||||
defer CertTs.Close()
|
||||
|
||||
sdkConf := sdk.Config{
|
||||
|
||||
Reference in New Issue
Block a user