NOISSUE - Remove SuperMQ duplicates (#23)

* Update docker-compose to use SuperMQ

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Remove duplicate services

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Update Bootstrap

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Update other services to use SMQ

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Switch config prefix to SMQ

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Remove leftovers

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Remove duplicate interface definitions

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Remove unused actions

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Remove unused API docs

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Resolve linter comments

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

* Fix provision

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>

---------

Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>
This commit is contained in:
Dušan Borovčanin
2024-12-31 11:04:17 +01:00
committed by GitHub
parent 57c3ecb175
commit 3bbb25bd64
699 changed files with 4836 additions and 130238 deletions
-209
View File
@@ -1,209 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package apiutil
import "github.com/absmach/magistrala/pkg/errors"
// Errors defined in this file are used by the LoggingErrorEncoder decorator
// to distinguish and log API request validation errors and avoid that service
// errors are logged twice.
var (
// ErrValidation indicates that an error was returned by the API.
ErrValidation = errors.New("something went wrong with the request")
// ErrBearerToken indicates missing or invalid bearer user token.
ErrBearerToken = errors.New("missing or invalid bearer user token")
// ErrBearerKey indicates missing or invalid bearer entity key.
ErrBearerKey = errors.New("missing or invalid bearer entity key")
// ErrMissingID indicates missing entity ID.
ErrMissingID = errors.New("missing entity id")
// ErrInvalidAuthKey indicates invalid auth key.
ErrInvalidAuthKey = errors.New("invalid auth key")
// ErrInvalidIDFormat indicates an invalid ID format.
ErrInvalidIDFormat = errors.New("invalid id format provided")
// ErrNameSize indicates that name size exceeds the max.
ErrNameSize = errors.New("invalid name size")
// ErrEmailSize indicates that email size exceeds the max.
ErrEmailSize = errors.New("invalid email size")
// ErrInvalidRole indicates that an invalid role.
ErrInvalidRole = errors.New("invalid client role")
// ErrLimitSize indicates that an invalid limit.
ErrLimitSize = errors.New("invalid limit size")
// ErrOffsetSize indicates an invalid offset.
ErrOffsetSize = errors.New("invalid offset size")
// ErrInvalidOrder indicates an invalid list order.
ErrInvalidOrder = errors.New("invalid list order provided")
// ErrInvalidDirection indicates an invalid list direction.
ErrInvalidDirection = errors.New("invalid list direction provided")
// ErrInvalidMemberKind indicates an invalid member kind.
ErrInvalidMemberKind = errors.New("invalid member kind")
// ErrEmptyList indicates that entity data is empty.
ErrEmptyList = errors.New("empty list provided")
// ErrMalformedPolicy indicates that policies are malformed.
ErrMalformedPolicy = errors.New("malformed policy")
// ErrMissingPolicySub indicates that policies are subject.
ErrMissingPolicySub = errors.New("malformed policy subject")
// ErrMissingPolicyObj indicates missing policies object.
ErrMissingPolicyObj = errors.New("malformed policy object")
// ErrMalformedPolicyAct indicates missing policies action.
ErrMalformedPolicyAct = errors.New("malformed policy action")
// ErrMissingPolicyEntityType indicates missing policies entity type.
ErrMissingPolicyEntityType = errors.New("missing policy entity type")
// ErrMalformedPolicyPer indicates missing policies relation.
ErrMalformedPolicyPer = errors.New("malformed policy permission")
// ErrMissingCertData indicates missing cert data (ttl).
ErrMissingCertData = errors.New("missing certificate data")
// ErrInvalidCertData indicates invalid cert data (ttl).
ErrInvalidCertData = errors.New("invalid certificate data")
// ErrInvalidTopic indicates an invalid subscription topic.
ErrInvalidTopic = errors.New("invalid Subscription topic")
// ErrInvalidContact indicates an invalid subscription contract.
ErrInvalidContact = errors.New("invalid Subscription contact")
// ErrMissingEmail indicates missing email.
ErrMissingEmail = errors.New("missing email")
// ErrInvalidEmail indicates missing email.
ErrInvalidEmail = errors.New("invalid email")
// ErrMissingHost indicates missing host.
ErrMissingHost = errors.New("missing host")
// ErrMissingPass indicates missing password.
ErrMissingPass = errors.New("missing password")
// ErrMissingConfPass indicates missing conf password.
ErrMissingConfPass = errors.New("missing conf password")
// ErrInvalidResetPass indicates an invalid reset password.
ErrInvalidResetPass = errors.New("invalid reset password")
// ErrInvalidComparator indicates an invalid comparator.
ErrInvalidComparator = errors.New("invalid comparator")
// ErrMissingMemberType indicates missing group member type.
ErrMissingMemberType = errors.New("missing group member type")
// ErrMissingMemberKind indicates missing group member kind.
ErrMissingMemberKind = errors.New("missing group member kind")
// ErrMissingRelation indicates missing relation.
ErrMissingRelation = errors.New("missing relation")
// ErrInvalidRelation indicates an invalid relation.
ErrInvalidRelation = errors.New("invalid relation")
// ErrInvalidAPIKey indicates an invalid API key type.
ErrInvalidAPIKey = errors.New("invalid api key type")
// ErrBootstrapState indicates an invalid bootstrap state.
ErrBootstrapState = errors.New("invalid bootstrap state")
// ErrInvitationState indicates an invalid invitation state.
ErrInvitationState = errors.New("invalid invitation state")
// ErrMissingIdentity indicates missing entity Identity.
ErrMissingIdentity = errors.New("missing entity identity")
// ErrMissingSecret indicates missing secret.
ErrMissingSecret = errors.New("missing secret")
// ErrPasswordFormat indicates weak password.
ErrPasswordFormat = errors.New("password does not meet the requirements")
// ErrMissingName indicates missing identity name.
ErrMissingName = errors.New("missing identity name")
// ErrMissingName indicates missing alias.
ErrMissingAlias = errors.New("missing alias")
// ErrInvalidLevel indicates an invalid group level.
ErrInvalidLevel = errors.New("invalid group level (should be between 0 and 5)")
// ErrNotFoundParam indicates that the parameter was not found in the query.
ErrNotFoundParam = errors.New("parameter not found in the query")
// ErrInvalidQueryParams indicates invalid query parameters.
ErrInvalidQueryParams = errors.New("invalid query parameters")
// ErrInvalidVisibilityType indicates invalid visibility type.
ErrInvalidVisibilityType = errors.New("invalid visibility type")
// ErrUnsupportedContentType indicates unacceptable or lack of Content-Type.
ErrUnsupportedContentType = errors.New("unsupported content type")
// ErrRollbackTx indicates failed to rollback transaction.
ErrRollbackTx = errors.New("failed to rollback transaction")
// ErrInvalidAggregation indicates invalid aggregation value.
ErrInvalidAggregation = errors.New("invalid aggregation value")
// ErrInvalidInterval indicates invalid interval value.
ErrInvalidInterval = errors.New("invalid interval value")
// ErrMissingFrom indicates missing from value.
ErrMissingFrom = errors.New("missing from time value")
// ErrMissingTo indicates missing to value.
ErrMissingTo = errors.New("missing to time value")
// ErrEmptyMessage indicates empty message.
ErrEmptyMessage = errors.New("empty message")
// ErrMissingEntityType indicates missing entity type.
ErrMissingEntityType = errors.New("missing entity type")
// ErrInvalidEntityType indicates invalid entity type.
ErrInvalidEntityType = errors.New("invalid entity type")
// ErrInvalidTimeFormat indicates invalid time format i.e not unix time.
ErrInvalidTimeFormat = errors.New("invalid time format use unix time")
// ErrEmptySearchQuery indicates search query should not be empty.
ErrEmptySearchQuery = errors.New("search query must not be empty")
// ErrLenSearchQuery indicates search query length.
ErrLenSearchQuery = errors.New("search query must be at least 3 characters")
// ErrMissingDomainID indicates missing domainID.
ErrMissingDomainID = errors.New("missing domainID")
// ErrMissingUsername indicates missing user name.
ErrMissingUsername = errors.New("missing username")
// ErrInvalidUsername indicates missing user name.
ErrInvalidUsername = errors.New("invalid username")
// ErrMissingFirstName indicates missing first name.
ErrMissingFirstName = errors.New("missing first name")
// ErrMissingLastName indicates missing last name.
ErrMissingLastName = errors.New("missing last name")
// ErrInvalidProfilePictureURL indicates that the profile picture url is invalid.
ErrInvalidProfilePictureURL = errors.New("invalid profile picture url")
)
-10
View File
@@ -1,10 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package apiutil
// ErrorRes represents the HTTP error response body.
type ErrorRes struct {
Err string `json:"error"`
Msg string `json:"message"`
}
-37
View File
@@ -1,37 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package apiutil
import (
"net/http"
"strings"
)
// BearerPrefix represents the token prefix for Bearer authentication scheme.
const BearerPrefix = "Bearer "
// ThingPrefix represents the key prefix for Thing authentication scheme.
const ThingPrefix = "Thing "
// ExtractBearerToken returns value of the bearer token. If there is no bearer token - an empty value is returned.
func ExtractBearerToken(r *http.Request) string {
token := r.Header.Get("Authorization")
if !strings.HasPrefix(token, BearerPrefix) {
return ""
}
return strings.TrimPrefix(token, BearerPrefix)
}
// ExtractThingKey returns value of the thing key. If there is no thing key - an empty value is returned.
func ExtractThingKey(r *http.Request) string {
token := r.Header.Get("Authorization")
if !strings.HasPrefix(token, ThingPrefix) {
return ""
}
return strings.TrimPrefix(token, ThingPrefix)
}
-112
View File
@@ -1,112 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package apiutil_test
import (
"net/http"
"testing"
"github.com/absmach/magistrala/pkg/apiutil"
"github.com/stretchr/testify/assert"
)
func TestExtractBearerToken(t *testing.T) {
cases := []struct {
desc string
request *http.Request
token string
}{
{
desc: "valid bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {"Bearer 123"},
},
},
token: "123",
},
{
desc: "invalid bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {"123"},
},
},
token: "",
},
{
desc: "empty bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {""},
},
},
token: "",
},
{
desc: "empty header",
request: &http.Request{
Header: map[string][]string{},
},
token: "",
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
token := apiutil.ExtractBearerToken(c.request)
assert.Equal(t, c.token, token)
})
}
}
func TestExtractThingKey(t *testing.T) {
cases := []struct {
desc string
request *http.Request
token string
}{
{
desc: "valid bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {"Thing 123"},
},
},
token: "123",
},
{
desc: "invalid bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {"123"},
},
},
token: "",
},
{
desc: "empty bearer token",
request: &http.Request{
Header: map[string][]string{
"Authorization": {""},
},
},
token: "",
},
{
desc: "empty header",
request: &http.Request{
Header: map[string][]string{},
},
token: "",
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
token := apiutil.ExtractThingKey(c.request)
assert.Equal(t, c.token, token)
})
}
}
-123
View File
@@ -1,123 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package apiutil
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"strconv"
"github.com/absmach/magistrala/pkg/errors"
kithttp "github.com/go-kit/kit/transport/http"
)
// LoggingErrorEncoder is a go-kit error encoder logging decorator.
func LoggingErrorEncoder(logger *slog.Logger, enc kithttp.ErrorEncoder) kithttp.ErrorEncoder {
return func(ctx context.Context, err error, w http.ResponseWriter) {
if errors.Contains(err, ErrValidation) {
logger.Error(err.Error())
}
enc(ctx, err, w)
}
}
// ReadStringQuery reads the value of string http query parameters for a given key.
func ReadStringQuery(r *http.Request, key, def string) (string, error) {
vals := r.URL.Query()[key]
if len(vals) > 1 {
return "", ErrInvalidQueryParams
}
if len(vals) == 0 {
return def, nil
}
return vals[0], nil
}
// ReadMetadataQuery reads the value of json http query parameters for a given key.
func ReadMetadataQuery(r *http.Request, key string, def map[string]interface{}) (map[string]interface{}, error) {
vals := r.URL.Query()[key]
if len(vals) > 1 {
return nil, ErrInvalidQueryParams
}
if len(vals) == 0 {
return def, nil
}
m := make(map[string]interface{})
err := json.Unmarshal([]byte(vals[0]), &m)
if err != nil {
return nil, errors.Wrap(ErrInvalidQueryParams, err)
}
return m, nil
}
// ReadBoolQuery reads boolean query parameters in a given http request.
func ReadBoolQuery(r *http.Request, key string, def bool) (bool, error) {
vals := r.URL.Query()[key]
if len(vals) > 1 {
return false, ErrInvalidQueryParams
}
if len(vals) == 0 {
return def, nil
}
b, err := strconv.ParseBool(vals[0])
if err != nil {
return false, errors.Wrap(ErrInvalidQueryParams, err)
}
return b, nil
}
type number interface {
int64 | float64 | uint16 | uint64
}
// ReadNumQuery returns a numeric value.
func ReadNumQuery[N number](r *http.Request, key string, def N) (N, error) {
vals := r.URL.Query()[key]
if len(vals) > 1 {
return 0, ErrInvalidQueryParams
}
if len(vals) == 0 {
return def, nil
}
val := vals[0]
switch any(def).(type) {
case int64:
v, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return 0, errors.Wrap(ErrInvalidQueryParams, err)
}
return N(v), nil
case uint64:
v, err := strconv.ParseUint(val, 10, 64)
if err != nil {
return 0, errors.Wrap(ErrInvalidQueryParams, err)
}
return N(v), nil
case uint16:
v, err := strconv.ParseUint(val, 10, 16)
if err != nil {
return 0, errors.Wrap(ErrInvalidQueryParams, err)
}
return N(v), nil
case float64:
v, err := strconv.ParseFloat(val, 64)
if err != nil {
return 0, errors.Wrap(ErrInvalidQueryParams, err)
}
return N(v), nil
default:
return def, nil
}
}
-364
View File
@@ -1,364 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package apiutil_test
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/apiutil"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/stretchr/testify/assert"
)
func TestReadStringQuery(t *testing.T) {
cases := []struct {
desc string
url string
key string
ret string
err error
}{
{
desc: "valid string query",
url: "http://localhost:8080/?key=test",
key: "key",
ret: "test",
err: nil,
},
{
desc: "empty string query",
url: "http://localhost:8080/",
key: "key",
ret: "",
err: nil,
},
{
desc: "multiple string query",
url: "http://localhost:8080/?key=test&key=random",
key: "key",
ret: "",
err: apiutil.ErrInvalidQueryParams,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
parsedURL, err := url.Parse(c.url)
assert.NoError(t, err)
r := &http.Request{URL: parsedURL}
ret, err := apiutil.ReadStringQuery(r, c.key, "")
assert.Equal(t, c.err, err)
assert.Equal(t, c.ret, ret)
})
}
}
func TestReadMetadataQuery(t *testing.T) {
cases := []struct {
desc string
url string
key string
ret map[string]interface{}
err error
}{
{
desc: "valid metadata query",
url: "http://localhost:8080/?key={\"test\":\"test\"}",
key: "key",
ret: map[string]interface{}{"test": "test"},
err: nil,
},
{
desc: "empty metadata query",
url: "http://localhost:8080/",
key: "key",
ret: nil,
err: nil,
},
{
desc: "multiple metadata query",
url: "http://localhost:8080/?key={\"test\":\"test\"}&key={\"random\":\"random\"}",
key: "key",
ret: nil,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "invalid metadata query",
url: "http://localhost:8080/?key=abc",
key: "key",
ret: nil,
err: apiutil.ErrInvalidQueryParams,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
parsedURL, err := url.Parse(c.url)
assert.NoError(t, err)
r := &http.Request{URL: parsedURL}
ret, err := apiutil.ReadMetadataQuery(r, c.key, nil)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err))
assert.Equal(t, c.ret, ret)
})
}
}
func TestReadBoolQuery(t *testing.T) {
cases := []struct {
desc string
url string
key string
ret bool
err error
}{
{
desc: "valid bool query",
url: "http://localhost:8080/?key=true",
key: "key",
ret: true,
err: nil,
},
{
desc: "valid bool query",
url: "http://localhost:8080/?key=false",
key: "key",
ret: false,
err: nil,
},
{
desc: "invalid bool query",
url: "http://localhost:8080/?key=abc",
key: "key",
ret: false,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "empty bool query",
url: "http://localhost:8080/",
key: "key",
ret: false,
err: nil,
},
{
desc: "multiple bool query",
url: "http://localhost:8080/?key=true&key=false",
key: "key",
ret: false,
err: apiutil.ErrInvalidQueryParams,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
parsedURL, err := url.Parse(c.url)
assert.NoError(t, err)
r := &http.Request{URL: parsedURL}
ret, err := apiutil.ReadBoolQuery(r, c.key, false)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err))
assert.Equal(t, c.ret, ret)
})
}
}
func TestReadNumQuery(t *testing.T) {
cases := []struct {
desc string
url string
key string
numType string
ret interface{}
err error
}{
{
desc: "valid int64 query",
url: "http://localhost:8080/?key=123",
key: "key",
numType: "int64",
ret: int64(123),
err: nil,
},
{
desc: "valid float64 query",
url: "http://localhost:8080/?key=1.23",
key: "key",
numType: "float64",
ret: float64(1.23),
err: nil,
},
{
desc: "valid uint64 query",
url: "http://localhost:8080/?key=123",
key: "key",
numType: "uint64",
ret: uint64(123),
err: nil,
},
{
desc: "valid uint16 query",
url: "http://localhost:8080/?key=123",
key: "key",
numType: "uint16",
ret: uint16(123),
err: nil,
},
{
desc: "invalid int64 query",
url: "http://localhost:8080/?key=abc",
key: "key",
numType: "int64",
ret: int64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "invalid float64 query",
url: "http://localhost:8080/?key=abc",
key: "key",
numType: "float64",
ret: float64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "invalid uint64 query",
url: "http://localhost:8080/?key=abc",
key: "key",
numType: "uint64",
ret: uint64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "invalid uint16 query",
url: "http://localhost:8080/?key=abc",
key: "key",
numType: "uint16",
ret: uint16(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "empty int64 query",
url: "http://localhost:8080/",
key: "key",
numType: "int64",
ret: int64(0),
err: nil,
},
{
desc: "empty float64 query",
url: "http://localhost:8080/",
key: "key",
numType: "float64",
ret: float64(0),
err: nil,
},
{
desc: "empty uint16 query",
url: "http://localhost:8080/",
key: "key",
numType: "uint16",
ret: uint16(0),
err: nil,
},
{
desc: "empty uint64 query",
url: "http://localhost:8080/",
key: "key",
numType: "uint64",
ret: uint64(0),
err: nil,
},
{
desc: "multiple int64 query",
url: "http://localhost:8080/?key=123&key=456",
key: "key",
numType: "int64",
ret: int64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "multiple float64 query",
url: "http://localhost:8080/?key=1.23&key=4.56",
key: "key",
numType: "float64",
ret: float64(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "multiple uint16 query",
url: "http://localhost:8080/?key=123&key=456",
key: "key",
numType: "uint16",
ret: uint16(0),
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "multiple uint64 query",
url: "http://localhost:8080/?key=123&key=456",
key: "key",
numType: "uint64",
ret: uint64(0),
err: apiutil.ErrInvalidQueryParams,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
parsedURL, err := url.Parse(c.url)
assert.NoError(t, err)
r := &http.Request{URL: parsedURL}
var ret interface{}
switch c.numType {
case "int64":
ret, err = apiutil.ReadNumQuery[int64](r, c.key, 0)
case "float64":
ret, err = apiutil.ReadNumQuery[float64](r, c.key, 0)
case "uint64":
ret, err = apiutil.ReadNumQuery[uint64](r, c.key, 0)
case "uint16":
ret, err = apiutil.ReadNumQuery[uint16](r, c.key, 0)
}
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err))
assert.Equal(t, c.ret, ret)
})
}
}
func TestLoggingErrorEncoder(t *testing.T) {
cases := []struct {
desc string
err error
}{
{
desc: "error contains ErrValidation",
err: errors.Wrap(apiutil.ErrValidation, svcerr.ErrAuthentication),
},
{
desc: "error does not contain ErrValidation",
err: svcerr.ErrAuthentication,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
encCalled := false
encFunc := func(ctx context.Context, err error, w http.ResponseWriter) {
encCalled = true
}
errorEncoder := apiutil.LoggingErrorEncoder(mglog.NewMock(), encFunc)
errorEncoder(context.Background(), c.err, httptest.NewRecorder())
assert.True(t, encCalled)
})
}
}
-22
View File
@@ -1,22 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package authn
import (
"context"
)
type Session struct {
DomainUserID string
UserID string
DomainID string
SuperAdmin bool
}
// Authn is magistrala authentication library.
//
//go:generate mockery --name Authentication --output=./mocks --filename authn.go --quiet --note "Copyright (c) Abstract Machines"
type Authentication interface {
Authenticate(ctx context.Context, token string) (Session, error)
}
-46
View File
@@ -1,46 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package authsvc
import (
"context"
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/auth/api/grpc/auth"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/errors"
"github.com/absmach/magistrala/pkg/grpcclient"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
type authentication struct {
authSvcClient magistrala.AuthServiceClient
}
var _ authn.Authentication = (*authentication)(nil)
func NewAuthentication(ctx context.Context, cfg grpcclient.Config) (authn.Authentication, grpcclient.Handler, error) {
client, err := grpcclient.NewHandler(cfg)
if err != nil {
return nil, nil, err
}
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "auth",
})
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
return nil, nil, grpcclient.ErrSvcNotServing
}
authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout)
return authentication{authSvcClient}, client, nil
}
func (a authentication) Authenticate(ctx context.Context, token string) (authn.Session, error) {
res, err := a.authSvcClient.Authenticate(ctx, &magistrala.AuthNReq{Token: token})
if err != nil {
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
}
return authn.Session{DomainUserID: res.GetId(), UserID: res.GetUserId(), DomainID: res.GetDomainId()}, nil
}
-4
View File
@@ -1,4 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package authn
-60
View File
@@ -1,60 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
authn "github.com/absmach/magistrala/pkg/authn"
mock "github.com/stretchr/testify/mock"
)
// Authentication is an autogenerated mock type for the Authentication type
type Authentication struct {
mock.Mock
}
// Authenticate provides a mock function with given fields: ctx, token
func (_m *Authentication) Authenticate(ctx context.Context, token string) (authn.Session, error) {
ret := _m.Called(ctx, token)
if len(ret) == 0 {
panic("no return value specified for Authenticate")
}
var r0 authn.Session
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (authn.Session, error)); ok {
return rf(ctx, token)
}
if rf, ok := ret.Get(0).(func(context.Context, string) authn.Session); ok {
r0 = rf(ctx, token)
} else {
r0 = ret.Get(0).(authn.Session)
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, token)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewAuthentication creates a new instance of Authentication. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewAuthentication(t interface {
mock.TestingT
Cleanup(func())
}) *Authentication {
mock := &Authentication{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-60
View File
@@ -1,60 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package authsvc
import (
"context"
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/auth/api/grpc/auth"
"github.com/absmach/magistrala/pkg/authz"
"github.com/absmach/magistrala/pkg/errors"
"github.com/absmach/magistrala/pkg/grpcclient"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
type authorization struct {
authSvcClient magistrala.AuthServiceClient
}
var _ authz.Authorization = (*authorization)(nil)
func NewAuthorization(ctx context.Context, cfg grpcclient.Config) (authz.Authorization, grpcclient.Handler, error) {
client, err := grpcclient.NewHandler(cfg)
if err != nil {
return nil, nil, err
}
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "auth",
})
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
return nil, nil, grpcclient.ErrSvcNotServing
}
authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout)
return authorization{authSvcClient}, client, nil
}
func (a authorization) Authorize(ctx context.Context, pr authz.PolicyReq) error {
req := magistrala.AuthZReq{
Domain: pr.Domain,
SubjectType: pr.SubjectType,
SubjectKind: pr.SubjectKind,
SubjectRelation: pr.SubjectRelation,
Subject: pr.Subject,
Relation: pr.Relation,
Permission: pr.Permission,
Object: pr.Object,
ObjectType: pr.ObjectType,
}
res, err := a.authSvcClient.Authorize(ctx, &req)
if err != nil {
return errors.Wrap(errors.ErrAuthorization, err)
}
if !res.Authorized {
return errors.ErrAuthorization
}
return nil
}
-50
View File
@@ -1,50 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package authz
import "context"
type PolicyReq struct {
// Domain contains the domain ID.
Domain string `json:"domain,omitempty"`
// Subject contains the subject ID or Token.
Subject string `json:"subject"`
// SubjectType contains the subject type. Supported subject types are
// platform, group, domain, thing, users.
SubjectType string `json:"subject_type"`
// SubjectKind contains the subject kind. Supported subject kinds are
// token, users, platform, things, channels, groups, domain.
SubjectKind string `json:"subject_kind"`
// SubjectRelation contains subject relations.
SubjectRelation string `json:"subject_relation,omitempty"`
// Object contains the object ID.
Object string `json:"object"`
// ObjectKind contains the object kind. Supported object kinds are
// users, platform, things, channels, groups, domain.
ObjectKind string `json:"object_kind"`
// ObjectType contains the object type. Supported object types are
// platform, group, domain, thing, users.
ObjectType string `json:"object_type"`
// Relation contains the relation. Supported relations are administrator, editor, contributor, member, guest, parent_group,group,domain.
Relation string `json:"relation,omitempty"`
// Permission contains the permission. Supported permissions are admin, delete, edit, share, view,
// membership, create, admin_only, edit_only, view_only, membership_only, ext_admin, ext_edit, ext_view.
Permission string `json:"permission,omitempty"`
}
// Authz is magistrala authorization library.
//
//go:generate mockery --name Authorization --output=./mocks --filename authz.go --quiet --note "Copyright (c) Abstract Machines"
type Authorization interface {
Authorize(ctx context.Context, pr PolicyReq) error
}
-4
View File
@@ -1,4 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package authz
-50
View File
@@ -1,50 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
authz "github.com/absmach/magistrala/pkg/authz"
mock "github.com/stretchr/testify/mock"
)
// Authorization is an autogenerated mock type for the Authorization type
type Authorization struct {
mock.Mock
}
// Authorize provides a mock function with given fields: ctx, pr
func (_m *Authorization) Authorize(ctx context.Context, pr authz.PolicyReq) error {
ret := _m.Called(ctx, pr)
if len(ret) == 0 {
panic("no return value specified for Authorize")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authz.PolicyReq) error); ok {
r0 = rf(ctx, pr)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewAuthorization creates a new instance of Authorization. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewAuthorization(t interface {
mock.TestingT
Cleanup(func())
}) *Authorization {
mock := &Authorization{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-87
View File
@@ -1,87 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package events
import (
"context"
"time"
)
const (
UnpublishedEventsCheckInterval = 1 * time.Minute
ConnCheckInterval = 100 * time.Millisecond
MaxUnpublishedEvents uint64 = 1e4
MaxEventStreamLen int64 = 1e6
)
// Event represents an event.
type Event interface {
// Encode encodes event to map.
Encode() (map[string]interface{}, error)
}
// Publisher specifies events publishing API.
//
//go:generate mockery --name Publisher --output=./mocks --filename publisher.go --quiet --note "Copyright (c) Abstract Machines"
type Publisher interface {
// Publish publishes event to stream.
Publish(ctx context.Context, event Event) error
// Close gracefully closes event publisher's connection.
Close() error
}
// EventHandler represents event handler for Subscriber.
type EventHandler interface {
// Handle handles events passed by underlying implementation.
Handle(ctx context.Context, event Event) error
}
// SubscriberConfig represents event subscriber configuration.
type SubscriberConfig struct {
Consumer string
Stream string
Handler EventHandler
}
// Subscriber specifies event subscription API.
//
//go:generate mockery --name Subscriber --output=./mocks --filename subscriber.go --quiet --note "Copyright (c) Abstract Machines"
type Subscriber interface {
// Subscribe subscribes to the event stream and consumes events.
Subscribe(ctx context.Context, cfg SubscriberConfig) error
// Close gracefully closes event subscriber's connection.
Close() error
}
// Read reads value from event map.
// If value is not of type T, returns default value.
func Read[T any](event map[string]interface{}, key string, def T) T {
val, ok := event[key].(T)
if !ok {
return def
}
return val
}
// ReadStringSlice reads string slice from event map.
// If value is not a string slice, returns empty slice.
func ReadStringSlice(event map[string]interface{}, key string) []string {
var res []string
vals, ok := event[key].([]interface{})
if !ok {
return res
}
for _, v := range vals {
if s, ok := v.(string); ok {
res = append(res, s)
}
}
return res
}
-67
View File
@@ -1,67 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
events "github.com/absmach/magistrala/pkg/events"
mock "github.com/stretchr/testify/mock"
)
// Publisher is an autogenerated mock type for the Publisher type
type Publisher struct {
mock.Mock
}
// Close provides a mock function with given fields:
func (_m *Publisher) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// Publish provides a mock function with given fields: ctx, event
func (_m *Publisher) Publish(ctx context.Context, event events.Event) error {
ret := _m.Called(ctx, event)
if len(ret) == 0 {
panic("no return value specified for Publish")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, events.Event) error); ok {
r0 = rf(ctx, event)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewPublisher creates a new instance of Publisher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewPublisher(t interface {
mock.TestingT
Cleanup(func())
}) *Publisher {
mock := &Publisher{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-67
View File
@@ -1,67 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
events "github.com/absmach/magistrala/pkg/events"
mock "github.com/stretchr/testify/mock"
)
// Subscriber is an autogenerated mock type for the Subscriber type
type Subscriber struct {
mock.Mock
}
// Close provides a mock function with given fields:
func (_m *Subscriber) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// Subscribe provides a mock function with given fields: ctx, cfg
func (_m *Subscriber) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error {
ret := _m.Called(ctx, cfg)
if len(ret) == 0 {
panic("no return value specified for Subscribe")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, events.SubscriberConfig) error); ok {
r0 = rf(ctx, cfg)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewSubscriber creates a new instance of Subscriber. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewSubscriber(t interface {
mock.TestingT
Cleanup(func())
}) *Subscriber {
mock := &Subscriber{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-8
View File
@@ -1,8 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package redis contains the domain concept definitions needed to support
// Magistrala redis events source service functionality.
//
// It provides the abstraction of the redis stream and its operations.
package nats
-79
View File
@@ -1,79 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nats
import (
"context"
"encoding/json"
"time"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/messaging"
broker "github.com/absmach/magistrala/pkg/messaging/nats"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
)
// Max message payload size is 1MB.
var reconnectBufSize = 1024 * 1024 * int(events.MaxUnpublishedEvents)
type pubEventStore struct {
url string
conn *nats.Conn
publisher messaging.Publisher
stream string
}
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
conn, err := nats.Connect(url, nats.MaxReconnects(maxReconnects), nats.ReconnectBufSize(reconnectBufSize))
if err != nil {
return nil, err
}
js, err := jetstream.New(conn)
if err != nil {
return nil, err
}
if _, err := js.CreateStream(ctx, jsStreamConfig); err != nil {
return nil, err
}
publisher, err := broker.NewPublisher(ctx, url, broker.Prefix(eventsPrefix), broker.JSStream(js))
if err != nil {
return nil, err
}
es := &pubEventStore{
url: url,
conn: conn,
publisher: publisher,
stream: stream,
}
return es, nil
}
func (es *pubEventStore) Publish(ctx context.Context, event events.Event) error {
values, err := event.Encode()
if err != nil {
return err
}
values["occurred_at"] = time.Now().UnixNano()
data, err := json.Marshal(values)
if err != nil {
return err
}
record := &messaging.Message{
Payload: data,
}
return es.publisher.Publish(ctx, es.stream, record)
}
func (es *pubEventStore) Close() error {
es.conn.Close()
return es.publisher.Close()
}
-325
View File
@@ -1,325 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nats_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/rand"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/events/nats"
"github.com/stretchr/testify/assert"
)
var (
eventsChan = make(chan map[string]interface{})
logger = mglog.NewMock()
errFailed = errors.New("failed")
numEvents = 100
)
type testEvent struct {
Data map[string]interface{}
}
func (te testEvent) Encode() (map[string]interface{}, error) {
data := make(map[string]interface{})
for k, v := range te.Data {
switch v.(type) {
case string:
data[k] = v
case float64:
data[k] = v
default:
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
data[k] = string(b)
}
}
return data, nil
}
func TestPublish(t *testing.T) {
_, err := nats.NewPublisher(context.Background(), "http://invaliurl.com", stream)
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
publisher, err := nats.NewPublisher(context.Background(), natsURL, stream)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
defer publisher.Close()
_, err = nats.NewSubscriber(context.Background(), "http://invaliurl.com", logger)
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
subcriber, err := nats.NewSubscriber(context.Background(), natsURL, logger)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
defer subcriber.Close()
cfg := events.SubscriberConfig{
Stream: "events." + stream,
Consumer: consumer,
Handler: handler{},
}
err = subcriber.Subscribe(context.Background(), cfg)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
cases := []struct {
desc string
event map[string]interface{}
err error
}{
{
desc: "publish event successfully",
err: nil,
event: map[string]interface{}{
"temperature": fmt.Sprintf("%f", rand.Float64()),
"humidity": fmt.Sprintf("%f", rand.Float64()),
"sensor_id": "abc123",
"location": "Earth",
"status": "normal",
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
"operation": "create",
"occurred_at": time.Now().UnixNano(),
},
},
{
desc: "publish with nil event",
err: nil,
event: nil,
},
{
desc: "publish event with invalid event location",
err: fmt.Errorf("json: unsupported type: chan int"),
event: map[string]interface{}{
"temperature": fmt.Sprintf("%f", rand.Float64()),
"humidity": fmt.Sprintf("%f", rand.Float64()),
"sensor_id": "abc123",
"location": make(chan int),
"status": "normal",
"timestamp": "invalid",
"operation": "create",
"occurred_at": time.Now().UnixNano(),
},
},
{
desc: "publish event with nested sting value",
err: nil,
event: map[string]interface{}{
"temperature": fmt.Sprintf("%f", rand.Float64()),
"humidity": fmt.Sprintf("%f", rand.Float64()),
"sensor_id": "abc123",
"location": map[string]string{
"lat": fmt.Sprintf("%f", rand.Float64()),
"lng": fmt.Sprintf("%f", rand.Float64()),
},
"status": "normal",
"timestamp": "invalid",
"operation": "create",
"occurred_at": time.Now().UnixNano(),
},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
event := testEvent{Data: tc.event}
err := publisher.Publish(context.Background(), event)
switch tc.err {
case nil:
receivedEvent := <-eventsChan
val := int64(receivedEvent["occurred_at"].(float64))
if assert.WithinRange(t, time.Unix(0, val), time.Now().Add(-time.Second), time.Now().Add(time.Second)) {
delete(receivedEvent, "occurred_at")
delete(tc.event, "occurred_at")
}
assert.Equal(t, tc.event["temperature"], receivedEvent["temperature"])
assert.Equal(t, tc.event["humidity"], receivedEvent["humidity"])
assert.Equal(t, tc.event["sensor_id"], receivedEvent["sensor_id"])
assert.Equal(t, tc.event["status"], receivedEvent["status"])
assert.Equal(t, tc.event["timestamp"], receivedEvent["timestamp"])
assert.Equal(t, tc.event["operation"], receivedEvent["operation"])
default:
assert.ErrorContains(t, err, tc.err.Error())
}
})
}
}
func TestPubsub(t *testing.T) {
cases := []struct {
desc string
stream string
consumer string
err error
handler events.EventHandler
}{
{
desc: "Subscribe to a stream",
stream: fmt.Sprintf("events.%s", stream),
consumer: consumer,
err: nil,
handler: handler{false},
},
{
desc: "Subscribe to the same stream",
stream: fmt.Sprintf("events.%s", stream),
consumer: consumer,
err: nil,
handler: handler{false},
},
{
desc: "Subscribe to an empty stream with an empty consumer",
stream: "",
consumer: "",
err: nats.ErrEmptyStream,
handler: handler{false},
},
{
desc: "Subscribe to an empty stream with a valid consumer",
stream: "",
consumer: consumer,
err: nats.ErrEmptyStream,
handler: handler{false},
},
{
desc: "Subscribe to a valid stream with an empty consumer",
stream: fmt.Sprintf("events.%s", stream),
consumer: "",
err: nats.ErrEmptyConsumer,
handler: handler{false},
},
{
desc: "Subscribe to another stream",
stream: fmt.Sprintf("events.%s.%d", stream, 1),
consumer: consumer,
err: nil,
handler: handler{false},
},
{
desc: "Subscribe to a stream with malformed handler",
stream: fmt.Sprintf("events.%s", stream),
consumer: consumer,
err: nil,
handler: handler{true},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
subcriber, err := nats.NewSubscriber(context.Background(), natsURL, logger)
if err != nil {
assert.Equal(t, err, tc.err)
return
}
cfg := events.SubscriberConfig{
Stream: tc.stream,
Consumer: tc.consumer,
Handler: tc.handler,
}
switch err := subcriber.Subscribe(context.Background(), cfg); {
case err == nil:
assert.Nil(t, err)
default:
assert.Equal(t, err, tc.err)
}
err = subcriber.Close()
assert.Nil(t, err)
})
}
}
func TestUnavailablePublish(t *testing.T) {
publisher, err := nats.NewPublisher(context.Background(), natsURL, stream)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
subcriber, err := nats.NewSubscriber(context.Background(), natsURL, logger)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
cfg := events.SubscriberConfig{
Stream: "events." + stream,
Consumer: consumer,
Handler: handler{},
}
err = subcriber.Subscribe(context.Background(), cfg)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
err = pool.Client.PauseContainer(container.Container.ID)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on pausing container: %s", err))
spawnGoroutines(publisher, t)
time.Sleep(1 * time.Second)
err = pool.Client.UnpauseContainer(container.Container.ID)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on unpausing container: %s", err))
// Wait for the events to be published.
time.Sleep(1 * time.Second)
err = publisher.Close()
assert.Nil(t, err, fmt.Sprintf("got unexpected error on closing publisher: %s", err))
// read all the events from the channel and assert that they are 10.
var receivedEvents []map[string]interface{}
for i := 0; i < numEvents; i++ {
event := <-eventsChan
receivedEvents = append(receivedEvents, event)
}
assert.Len(t, receivedEvents, numEvents, "got unexpected number of events")
}
func generateRandomEvent() testEvent {
return testEvent{
Data: map[string]interface{}{
"temperature": fmt.Sprintf("%f", rand.Float64()),
"humidity": fmt.Sprintf("%f", rand.Float64()),
"sensor_id": fmt.Sprintf("%d", rand.Intn(1000)),
"location": fmt.Sprintf("%f", rand.Float64()),
"status": fmt.Sprintf("%d", rand.Intn(1000)),
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
"operation": "create",
},
}
}
func spawnGoroutines(publisher events.Publisher, t *testing.T) {
for i := 0; i < numEvents; i++ {
go func() {
err := publisher.Publish(context.Background(), generateRandomEvent())
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
}()
}
}
type handler struct {
fail bool
}
func (h handler) Handle(_ context.Context, event events.Event) error {
if h.fail {
return errFailed
}
data, err := event.Encode()
if err != nil {
return err
}
eventsChan <- data
return nil
}
-81
View File
@@ -1,81 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nats_test
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"testing"
"github.com/absmach/magistrala/pkg/events/nats"
"github.com/ory/dockertest/v3"
)
var (
natsURL string
stream = "tests.events"
consumer = "tests-consumer"
pool *dockertest.Pool
container *dockertest.Resource
)
func TestMain(m *testing.M) {
var err error
pool, err = dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err = pool.RunWithOptions(&dockertest.RunOptions{
Repository: "nats",
Tag: "2.10.9-alpine",
Cmd: []string{"-DVV", "-js"},
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
handleInterrupt(pool, container)
natsURL = fmt.Sprintf("nats://%s:%s", "localhost", container.GetPort("4222/tcp"))
if err := pool.Retry(func() error {
_, err = nats.NewPublisher(context.Background(), natsURL, stream)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
if err := pool.Retry(func() error {
_, err = nats.NewSubscriber(context.Background(), natsURL, logger)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
code := m.Run()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
}
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(0)
}()
}
-138
View File
@@ -1,138 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nats
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"time"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/messaging"
broker "github.com/absmach/magistrala/pkg/messaging/nats"
"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
)
const maxReconnects = -1
var _ events.Subscriber = (*subEventStore)(nil)
var (
eventsPrefix = "events"
jsStreamConfig = jetstream.StreamConfig{
Name: "events",
Description: "Magistrala stream for sending and receiving messages in between Magistrala events",
Subjects: []string{"events.>"},
Retention: jetstream.LimitsPolicy,
MaxMsgsPerSubject: 1e9,
MaxAge: time.Hour * 24,
MaxMsgSize: 1024 * 1024,
Discard: jetstream.DiscardOld,
Storage: jetstream.FileStorage,
}
// ErrEmptyStream is returned when stream name is empty.
ErrEmptyStream = errors.New("stream name cannot be empty")
// ErrEmptyConsumer is returned when consumer name is empty.
ErrEmptyConsumer = errors.New("consumer name cannot be empty")
)
type subEventStore struct {
conn *nats.Conn
pubsub messaging.PubSub
logger *slog.Logger
}
func NewSubscriber(ctx context.Context, url string, logger *slog.Logger) (events.Subscriber, error) {
conn, err := nats.Connect(url, nats.MaxReconnects(maxReconnects))
if err != nil {
return nil, err
}
js, err := jetstream.New(conn)
if err != nil {
return nil, err
}
jsStream, err := js.CreateStream(ctx, jsStreamConfig)
if err != nil {
return nil, err
}
pubsub, err := broker.NewPubSub(ctx, url, logger, broker.Stream(jsStream))
if err != nil {
return nil, err
}
return &subEventStore{
conn: conn,
pubsub: pubsub,
logger: logger,
}, nil
}
func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error {
if cfg.Stream == "" {
return ErrEmptyStream
}
if cfg.Consumer == "" {
return ErrEmptyConsumer
}
subCfg := messaging.SubscriberConfig{
ID: cfg.Consumer,
Topic: cfg.Stream,
Handler: &eventHandler{
handler: cfg.Handler,
ctx: ctx,
logger: es.logger,
},
DeliveryPolicy: messaging.DeliverNewPolicy,
}
return es.pubsub.Subscribe(ctx, subCfg)
}
func (es *subEventStore) Close() error {
es.conn.Close()
return es.pubsub.Close()
}
type event struct {
Data map[string]interface{}
}
func (re event) Encode() (map[string]interface{}, error) {
return re.Data, nil
}
type eventHandler struct {
handler events.EventHandler
ctx context.Context
logger *slog.Logger
}
func (eh *eventHandler) Handle(msg *messaging.Message) error {
event := event{
Data: make(map[string]interface{}),
}
if err := json.Unmarshal(msg.GetPayload(), &event.Data); err != nil {
return err
}
if err := eh.handler.Handle(eh.ctx, event); err != nil {
eh.logger.Warn(fmt.Sprintf("failed to handle nats event: %s", err))
}
return nil
}
func (eh *eventHandler) Cancel() error {
return nil
}
-8
View File
@@ -1,8 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package redis contains the domain concept definitions needed to support
// Magistrala redis events source service functionality.
//
// It provides the abstraction of the redis stream and its operations.
package rabbitmq
-73
View File
@@ -1,73 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package rabbitmq
import (
"context"
"encoding/json"
"time"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/messaging"
broker "github.com/absmach/magistrala/pkg/messaging/rabbitmq"
amqp "github.com/rabbitmq/amqp091-go"
)
type pubEventStore struct {
conn *amqp.Connection
publisher messaging.Publisher
stream string
}
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
conn, err := amqp.Dial(url)
if err != nil {
return nil, err
}
ch, err := conn.Channel()
if err != nil {
return nil, err
}
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
return nil, err
}
publisher, err := broker.NewPublisher(url, broker.Prefix(eventsPrefix), broker.Exchange(exchangeName), broker.Channel(ch))
if err != nil {
return nil, err
}
es := &pubEventStore{
conn: conn,
publisher: publisher,
stream: stream,
}
return es, nil
}
func (es *pubEventStore) Publish(ctx context.Context, event events.Event) error {
values, err := event.Encode()
if err != nil {
return err
}
values["occurred_at"] = time.Now().UnixNano()
data, err := json.Marshal(values)
if err != nil {
return err
}
record := &messaging.Message{
Payload: data,
}
return es.publisher.Publish(ctx, es.stream, record)
}
func (es *pubEventStore) Close() error {
es.conn.Close()
return es.publisher.Close()
}
-326
View File
@@ -1,326 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package rabbitmq_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/rand"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/events/rabbitmq"
"github.com/stretchr/testify/assert"
)
var (
eventsChan = make(chan map[string]interface{})
logger = mglog.NewMock()
errFailed = errors.New("failed")
numEvents = 100
)
type testEvent struct {
Data map[string]interface{}
}
func (te testEvent) Encode() (map[string]interface{}, error) {
data := make(map[string]interface{})
for k, v := range te.Data {
switch v.(type) {
case string:
data[k] = v
case float64:
data[k] = v
default:
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
data[k] = string(b)
}
}
return data, nil
}
func TestPublish(t *testing.T) {
_, err := rabbitmq.NewPublisher(context.Background(), "http://invaliurl.com", stream)
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
publisher, err := rabbitmq.NewPublisher(context.Background(), rabbitmqURL, stream)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
defer publisher.Close()
_, err = rabbitmq.NewSubscriber("http://invaliurl.com", logger)
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
subcriber, err := rabbitmq.NewSubscriber(rabbitmqURL, logger)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
defer subcriber.Close()
cfg := events.SubscriberConfig{
Stream: "events." + stream,
Consumer: consumer,
Handler: handler{},
}
err = subcriber.Subscribe(context.Background(), cfg)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
cases := []struct {
desc string
event map[string]interface{}
err error
}{
{
desc: "publish event successfully",
err: nil,
event: map[string]interface{}{
"temperature": fmt.Sprintf("%f", rand.Float64()),
"humidity": fmt.Sprintf("%f", rand.Float64()),
"sensor_id": "abc123",
"location": "Earth",
"status": "normal",
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
"operation": "create",
"occurred_at": time.Now().UnixNano(),
},
},
{
desc: "publish with nil event",
err: nil,
event: nil,
},
{
desc: "publish event with invalid event location",
err: fmt.Errorf("json: unsupported type: chan int"),
event: map[string]interface{}{
"temperature": fmt.Sprintf("%f", rand.Float64()),
"humidity": fmt.Sprintf("%f", rand.Float64()),
"sensor_id": "abc123",
"location": make(chan int),
"status": "normal",
"timestamp": "invalid",
"operation": "create",
"occurred_at": time.Now().UnixNano(),
},
},
{
desc: "publish event with nested sting value",
err: nil,
event: map[string]interface{}{
"temperature": fmt.Sprintf("%f", rand.Float64()),
"humidity": fmt.Sprintf("%f", rand.Float64()),
"sensor_id": "abc123",
"location": map[string]string{
"lat": fmt.Sprintf("%f", rand.Float64()),
"lng": fmt.Sprintf("%f", rand.Float64()),
},
"status": "normal",
"timestamp": "invalid",
"operation": "create",
"occurred_at": time.Now().UnixNano(),
},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
event := testEvent{Data: tc.event}
err := publisher.Publish(context.Background(), event)
switch tc.err {
case nil:
receivedEvent := <-eventsChan
val := int64(receivedEvent["occurred_at"].(float64))
if assert.WithinRange(t, time.Unix(0, val), time.Now().Add(-time.Second), time.Now().Add(time.Second)) {
delete(receivedEvent, "occurred_at")
delete(tc.event, "occurred_at")
}
assert.Equal(t, tc.event["temperature"], receivedEvent["temperature"])
assert.Equal(t, tc.event["humidity"], receivedEvent["humidity"])
assert.Equal(t, tc.event["sensor_id"], receivedEvent["sensor_id"])
assert.Equal(t, tc.event["status"], receivedEvent["status"])
assert.Equal(t, tc.event["timestamp"], receivedEvent["timestamp"])
assert.Equal(t, tc.event["operation"], receivedEvent["operation"])
default:
assert.ErrorContains(t, err, tc.err.Error())
}
})
}
}
func TestPubsub(t *testing.T) {
cases := []struct {
desc string
stream string
consumer string
err error
handler events.EventHandler
}{
{
desc: "Subscribe to a stream",
stream: fmt.Sprintf("events.%s", stream),
consumer: consumer,
err: nil,
handler: handler{false},
},
{
desc: "Subscribe to the same stream",
stream: fmt.Sprintf("events.%s", stream),
consumer: consumer,
err: nil,
handler: handler{false},
},
{
desc: "Subscribe to an empty stream with an empty consumer",
stream: "",
consumer: "",
err: rabbitmq.ErrEmptyStream,
handler: handler{false},
},
{
desc: "Subscribe to an empty stream with a valid consumer",
stream: "",
consumer: consumer,
err: rabbitmq.ErrEmptyStream,
handler: handler{false},
},
{
desc: "Subscribe to a valid stream with an empty consumer",
stream: fmt.Sprintf("events.%s", stream),
consumer: "",
err: rabbitmq.ErrEmptyConsumer,
handler: handler{false},
},
{
desc: "Subscribe to another stream",
stream: fmt.Sprintf("events.%s.%d", stream, 1),
consumer: consumer,
err: nil,
handler: handler{false},
},
{
desc: "Subscribe to a stream with malformed handler",
stream: fmt.Sprintf("events.%s", stream),
consumer: consumer,
err: nil,
handler: handler{true},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
subcriber, err := rabbitmq.NewSubscriber(rabbitmqURL, logger)
if err != nil {
assert.Equal(t, err, tc.err)
return
}
cfg := events.SubscriberConfig{
Stream: tc.stream,
Consumer: tc.consumer,
Handler: tc.handler,
}
switch err := subcriber.Subscribe(context.Background(), cfg); {
case err == nil:
assert.Nil(t, err)
default:
assert.Equal(t, err, tc.err)
}
err = subcriber.Close()
assert.Nil(t, err)
})
}
}
func TestUnavailablePublish(t *testing.T) {
publisher, err := rabbitmq.NewPublisher(context.Background(), rabbitmqURL, stream)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
subcriber, err := rabbitmq.NewSubscriber(rabbitmqURL, logger)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
cfg := events.SubscriberConfig{
Stream: "events." + stream,
Consumer: consumer,
Handler: handler{},
}
err = subcriber.Subscribe(context.Background(), cfg)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
err = pool.Client.PauseContainer(container.Container.ID)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on pausing container: %s", err))
spawnGoroutines(publisher, t)
time.Sleep(1 * time.Second)
err = pool.Client.UnpauseContainer(container.Container.ID)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on unpausing container: %s", err))
// Wait for the events to be published.
time.Sleep(1 * time.Second)
err = publisher.Close()
assert.Nil(t, err, fmt.Sprintf("got unexpected error on closing publisher: %s", err))
// read all the events from the channel and assert that they are 10.
var receivedEvents []map[string]interface{}
for i := 0; i < numEvents; i++ {
event := <-eventsChan
receivedEvents = append(receivedEvents, event)
}
assert.Len(t, receivedEvents, numEvents, "got unexpected number of events")
}
func generateRandomEvent() testEvent {
return testEvent{
Data: map[string]interface{}{
"temperature": fmt.Sprintf("%f", rand.Float64()),
"humidity": fmt.Sprintf("%f", rand.Float64()),
"sensor_id": fmt.Sprintf("%d", rand.Intn(1000)),
"location": fmt.Sprintf("%f", rand.Float64()),
"status": fmt.Sprintf("%d", rand.Intn(1000)),
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
"operation": "create",
},
}
}
func spawnGoroutines(publisher events.Publisher, t *testing.T) {
for i := 0; i < numEvents; i++ {
go func() {
err := publisher.Publish(context.Background(), generateRandomEvent())
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
}()
}
}
type handler struct {
fail bool
}
func (h handler) Handle(_ context.Context, event events.Event) error {
if h.fail {
return errFailed
}
data, err := event.Encode()
if err != nil {
return err
}
eventsChan <- data
return nil
}
-79
View File
@@ -1,79 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package rabbitmq_test
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"testing"
"github.com/absmach/magistrala/pkg/events/rabbitmq"
"github.com/ory/dockertest/v3"
)
var (
rabbitmqURL string
stream = "tests.events"
consumer = "tests-consumer"
pool *dockertest.Pool
container *dockertest.Resource
)
func TestMain(m *testing.M) {
var err error
pool, err = dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err = pool.RunWithOptions(&dockertest.RunOptions{
Repository: "rabbitmq",
Tag: "3.12.12",
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
handleInterrupt(pool, container)
rabbitmqURL = fmt.Sprintf("amqp://%s:%s", "localhost", container.GetPort("5672/tcp"))
if err := pool.Retry(func() error {
_, err = rabbitmq.NewPublisher(context.Background(), rabbitmqURL, stream)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
if err := pool.Retry(func() error {
_, err = rabbitmq.NewSubscriber(rabbitmqURL, logger)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
code := m.Run()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
}
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
c := make(chan os.Signal, 2)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(0)
}()
}
-122
View File
@@ -1,122 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package rabbitmq
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/messaging"
broker "github.com/absmach/magistrala/pkg/messaging/rabbitmq"
amqp "github.com/rabbitmq/amqp091-go"
)
var _ events.Subscriber = (*subEventStore)(nil)
var (
exchangeName = "events"
eventsPrefix = "events"
// ErrEmptyStream is returned when stream name is empty.
ErrEmptyStream = errors.New("stream name cannot be empty")
// ErrEmptyConsumer is returned when consumer name is empty.
ErrEmptyConsumer = errors.New("consumer name cannot be empty")
)
type subEventStore struct {
conn *amqp.Connection
pubsub messaging.PubSub
logger *slog.Logger
}
func NewSubscriber(url string, logger *slog.Logger) (events.Subscriber, error) {
conn, err := amqp.Dial(url)
if err != nil {
return nil, err
}
ch, err := conn.Channel()
if err != nil {
return nil, err
}
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
return nil, err
}
pubsub, err := broker.NewPubSub(url, logger, broker.Channel(ch), broker.Exchange(exchangeName))
if err != nil {
return nil, err
}
return &subEventStore{
conn: conn,
pubsub: pubsub,
logger: logger,
}, nil
}
func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error {
if cfg.Stream == "" {
return ErrEmptyStream
}
if cfg.Consumer == "" {
return ErrEmptyConsumer
}
subCfg := messaging.SubscriberConfig{
ID: cfg.Consumer,
Topic: cfg.Stream,
Handler: &eventHandler{
handler: cfg.Handler,
ctx: ctx,
logger: es.logger,
},
DeliveryPolicy: messaging.DeliverNewPolicy,
}
return es.pubsub.Subscribe(ctx, subCfg)
}
func (es *subEventStore) Close() error {
es.conn.Close()
return es.pubsub.Close()
}
type event struct {
Data map[string]interface{}
}
func (re event) Encode() (map[string]interface{}, error) {
return re.Data, nil
}
type eventHandler struct {
handler events.EventHandler
ctx context.Context
logger *slog.Logger
}
func (eh *eventHandler) Handle(msg *messaging.Message) error {
event := event{
Data: make(map[string]interface{}),
}
if err := json.Unmarshal(msg.GetPayload(), &event.Data); err != nil {
return err
}
if err := eh.handler.Handle(eh.ctx, event); err != nil {
eh.logger.Warn(fmt.Sprintf("failed to handle rabbitmq event: %s", err))
}
return nil
}
func (eh *eventHandler) Cancel() error {
return nil
}
-8
View File
@@ -1,8 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package redis contains the domain concept definitions needed to support
// Magistrala redis events source service functionality.
//
// It provides the abstraction of the redis stream and its operations.
package redis
-118
View File
@@ -1,118 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package redis
import (
"context"
"encoding/json"
"sync"
"time"
"github.com/absmach/magistrala/pkg/events"
"github.com/redis/go-redis/v9"
)
type pubEventStore struct {
client *redis.Client
unpublishedEvents chan *redis.XAddArgs
stream string
mu sync.Mutex
flushPeriod time.Duration
}
func NewPublisher(ctx context.Context, url, stream string, flushPeriod time.Duration) (events.Publisher, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
es := &pubEventStore{
client: redis.NewClient(opts),
unpublishedEvents: make(chan *redis.XAddArgs, events.MaxUnpublishedEvents),
stream: eventsPrefix + stream,
flushPeriod: flushPeriod,
}
go es.flushUnpublished(ctx)
return es, nil
}
func (es *pubEventStore) Publish(ctx context.Context, event events.Event) error {
values, err := event.Encode()
if err != nil {
return err
}
values["occurred_at"] = time.Now().UnixNano()
data, err := json.Marshal(values)
if err != nil {
return err
}
record := &redis.XAddArgs{
Stream: es.stream,
MaxLen: events.MaxEventStreamLen,
Approx: true,
Values: map[string]interface{}{"data": string(data)},
}
switch err := es.checkConnection(ctx); err {
case nil:
return es.client.XAdd(ctx, record).Err()
default:
es.mu.Lock()
defer es.mu.Unlock()
// If the channel is full (rarely happens), drop the events.
if len(es.unpublishedEvents) == int(events.MaxUnpublishedEvents) {
return nil
}
es.unpublishedEvents <- record
return nil
}
}
// flushUnpublished periodically checks the Redis connection and publishes
// the events that were not published due to a connection error.
func (es *pubEventStore) flushUnpublished(ctx context.Context) {
defer close(es.unpublishedEvents)
ticker := time.NewTicker(es.flushPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := es.checkConnection(ctx); err == nil {
es.mu.Lock()
for i := len(es.unpublishedEvents) - 1; i >= 0; i-- {
record := <-es.unpublishedEvents
if err := es.client.XAdd(ctx, record).Err(); err != nil {
es.unpublishedEvents <- record
break
}
}
es.mu.Unlock()
}
case <-ctx.Done():
return
}
}
}
func (es *pubEventStore) Close() error {
return es.client.Close()
}
func (es *pubEventStore) checkConnection(ctx context.Context) error {
// A timeout is used to avoid blocking the main thread
ctx, cancel := context.WithTimeout(ctx, events.ConnCheckInterval)
defer cancel()
return es.client.Ping(ctx).Err()
}
-321
View File
@@ -1,321 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package redis_test
import (
"context"
"errors"
"fmt"
"math/rand"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/events/redis"
"github.com/stretchr/testify/assert"
)
var (
stream = "tests.events"
consumer = "test-consumer"
eventsChan = make(chan map[string]interface{})
logger = mglog.NewMock()
errFailed = errors.New("failed")
numEvents = 100
)
type testEvent struct {
Data map[string]interface{}
}
func (te testEvent) Encode() (map[string]interface{}, error) {
if te.Data == nil {
return map[string]interface{}{}, nil
}
return te.Data, nil
}
func TestPublish(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error on flushing redis: %s", err))
_, err = redis.NewPublisher(context.Background(), "http://invaliurl.com", stream, events.UnpublishedEventsCheckInterval)
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
publisher, err := redis.NewPublisher(context.Background(), redisURL, stream, events.UnpublishedEventsCheckInterval)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
defer publisher.Close()
_, err = redis.NewSubscriber("http://invaliurl.com", logger)
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
subcriber, err := redis.NewSubscriber(redisURL, logger)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
defer subcriber.Close()
cfg := events.SubscriberConfig{
Stream: "events." + stream,
Consumer: consumer,
Handler: handler{},
}
err = subcriber.Subscribe(context.Background(), cfg)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
cases := []struct {
desc string
event map[string]interface{}
err error
}{
{
desc: "publish event successfully",
err: nil,
event: map[string]interface{}{
"temperature": float64(rand.Float64()),
"humidity": float64(rand.Float64()),
"sensor_id": "abc123",
"location": "Earth",
"status": "normal",
"timestamp": float64(time.Now().UnixNano()),
"operation": "create",
"occurred_at": time.Now().UnixNano(),
},
},
{
desc: "publish with nil event",
err: nil,
event: nil,
},
{
desc: "publish event with invalid event location",
err: fmt.Errorf("json: unsupported type: chan int"),
event: map[string]interface{}{
"temperature": float64(rand.Float64()),
"humidity": float64(rand.Float64()),
"sensor_id": "abc123",
"location": make(chan int),
"status": "normal",
"timestamp": "invalid",
"operation": "create",
"occurred_at": float64(time.Now().UnixNano()),
},
},
{
desc: "publish event with nested sting value",
err: nil,
event: map[string]interface{}{
"temperature": float64(rand.Float64()),
"humidity": float64(rand.Float64()),
"sensor_id": "abc123",
"location": map[string]string{
"lat": fmt.Sprintf("%f", rand.Float64()),
"lng": fmt.Sprintf("%f", rand.Float64()),
},
"status": "normal",
"timestamp": "invalid",
"operation": "create",
"occurred_at": float64(time.Now().UnixNano()),
},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
event := testEvent{Data: tc.event}
err := publisher.Publish(context.Background(), event)
switch tc.err {
case nil:
receivedEvent := <-eventsChan
roa := receivedEvent["occurred_at"].(float64)
assert.Nil(t, err)
if assert.WithinRange(t, time.Unix(0, int64(roa)), time.Now().Add(-time.Second), time.Now().Add(time.Second)) {
delete(receivedEvent, "occurred_at")
delete(tc.event, "occurred_at")
}
assert.Equal(t, tc.event["temperature"], receivedEvent["temperature"])
assert.Equal(t, tc.event["humidity"], receivedEvent["humidity"])
assert.Equal(t, tc.event["sensor_id"], receivedEvent["sensor_id"])
assert.Equal(t, tc.event["status"], receivedEvent["status"])
assert.Equal(t, tc.event["timestamp"], receivedEvent["timestamp"])
assert.Equal(t, tc.event["operation"], receivedEvent["operation"])
default:
assert.ErrorContains(t, err, tc.err.Error())
}
})
}
}
func TestPubsub(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error on flushing redis: %s", err))
cases := []struct {
desc string
stream string
consumer string
err error
handler events.EventHandler
}{
{
desc: "Subscribe to a stream",
stream: fmt.Sprintf("events.%s", stream),
consumer: consumer,
err: nil,
handler: handler{false},
},
{
desc: "Subscribe to the same stream",
stream: fmt.Sprintf("events.%s", stream),
consumer: consumer,
err: nil,
handler: handler{false},
},
{
desc: "Subscribe to an empty stream with an empty consumer",
stream: "",
consumer: "",
err: redis.ErrEmptyStream,
handler: handler{false},
},
{
desc: "Subscribe to an empty stream with a valid consumer",
stream: "",
consumer: consumer,
err: redis.ErrEmptyStream,
handler: handler{false},
},
{
desc: "Subscribe to a valid stream with an empty consumer",
stream: fmt.Sprintf("events.%s", stream),
consumer: "",
err: redis.ErrEmptyConsumer,
handler: handler{false},
},
{
desc: "Subscribe to another stream",
stream: fmt.Sprintf("events.%s.%d", stream, 1),
consumer: consumer,
err: nil,
handler: handler{false},
},
{
desc: "Subscribe to a stream with malformed handler",
stream: fmt.Sprintf("events.%s", stream),
consumer: consumer,
err: nil,
handler: handler{true},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
subcriber, err := redis.NewSubscriber(redisURL, logger)
if err != nil {
assert.Equal(t, err, tc.err)
return
}
cfg := events.SubscriberConfig{
Stream: tc.stream,
Consumer: tc.consumer,
Handler: tc.handler,
}
switch err := subcriber.Subscribe(context.Background(), cfg); {
case err == nil:
assert.Nil(t, err)
default:
assert.Equal(t, err, tc.err)
}
err = subcriber.Close()
assert.Nil(t, err)
})
}
}
func TestUnavailablePublish(t *testing.T) {
publisher, err := redis.NewPublisher(context.Background(), redisURL, stream, time.Second)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
subcriber, err := redis.NewSubscriber(redisURL, logger)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
cfg := events.SubscriberConfig{
Stream: "events." + stream,
Consumer: consumer,
Handler: handler{},
}
err = subcriber.Subscribe(context.Background(), cfg)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
err = pool.Client.PauseContainer(container.Container.ID)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on pausing container: %s", err))
spawnGoroutines(publisher, t)
time.Sleep(1 * time.Second)
err = pool.Client.UnpauseContainer(container.Container.ID)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on unpausing container: %s", err))
// Wait for the events to be published.
time.Sleep(1 * time.Second)
err = publisher.Close()
assert.Nil(t, err, fmt.Sprintf("got unexpected error on closing publisher: %s", err))
var receivedEvents []map[string]interface{}
for i := 0; i < numEvents; i++ {
event := <-eventsChan
receivedEvents = append(receivedEvents, event)
}
assert.Len(t, receivedEvents, numEvents, "got unexpected number of events")
}
func generateRandomEvent() testEvent {
return testEvent{
Data: map[string]interface{}{
"temperature": fmt.Sprintf("%f", rand.Float64()),
"humidity": fmt.Sprintf("%f", rand.Float64()),
"sensor_id": fmt.Sprintf("%d", rand.Intn(1000)),
"location": fmt.Sprintf("%f", rand.Float64()),
"status": fmt.Sprintf("%d", rand.Intn(1000)),
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
"operation": "create",
},
}
}
func spawnGoroutines(publisher events.Publisher, t *testing.T) {
for i := 0; i < numEvents; i++ {
go func() {
err := publisher.Publish(context.Background(), generateRandomEvent())
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
}()
}
}
type handler struct {
fail bool
}
func (h handler) Handle(_ context.Context, event events.Event) error {
if h.fail {
return errFailed
}
data, err := event.Encode()
if err != nil {
return err
}
eventsChan <- data
return nil
}
-77
View File
@@ -1,77 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package redis_test
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"testing"
"github.com/ory/dockertest/v3"
"github.com/redis/go-redis/v9"
)
var (
redisClient *redis.Client
redisURL string
pool *dockertest.Pool
container *dockertest.Resource
)
func TestMain(m *testing.M) {
var err error
pool, err = dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err = pool.RunWithOptions(&dockertest.RunOptions{
Repository: "redis",
Tag: "7.2.4-alpine",
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
handleInterrupt(pool, container)
redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp"))
ropts, err := redis.ParseURL(redisURL)
if err != nil {
log.Fatalf("Could not parse redis URL: %s", err)
}
if err := pool.Retry(func() error {
redisClient = redis.NewClient(ropts)
return redisClient.Ping(context.Background()).Err()
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
code := m.Run()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
}
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(0)
}()
}
-125
View File
@@ -1,125 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package redis
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"github.com/absmach/magistrala/pkg/events"
"github.com/redis/go-redis/v9"
)
const (
eventsPrefix = "events."
eventCount = 100
exists = "BUSYGROUP Consumer Group name already exists"
group = "magistrala"
)
var _ events.Subscriber = (*subEventStore)(nil)
var (
// ErrEmptyStream is returned when stream name is empty.
ErrEmptyStream = errors.New("stream name cannot be empty")
// ErrEmptyConsumer is returned when consumer name is empty.
ErrEmptyConsumer = errors.New("consumer name cannot be empty")
)
type subEventStore struct {
client *redis.Client
logger *slog.Logger
}
func NewSubscriber(url string, logger *slog.Logger) (events.Subscriber, error) {
opts, err := redis.ParseURL(url)
if err != nil {
return nil, err
}
return &subEventStore{
client: redis.NewClient(opts),
logger: logger,
}, nil
}
func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error {
if cfg.Stream == "" {
return ErrEmptyStream
}
if cfg.Consumer == "" {
return ErrEmptyConsumer
}
err := es.client.XGroupCreateMkStream(ctx, cfg.Stream, group, "$").Err()
if err != nil && err.Error() != exists {
return err
}
go func() {
for {
msgs, err := es.client.XReadGroup(ctx, &redis.XReadGroupArgs{
Group: group,
Consumer: cfg.Consumer,
Streams: []string{cfg.Stream, ">"},
Count: eventCount,
}).Result()
if err != nil {
es.logger.Warn(fmt.Sprintf("failed to read from redis stream: %s", err))
continue
}
if len(msgs) == 0 {
continue
}
es.handle(ctx, cfg.Stream, msgs[0].Messages, cfg.Handler)
}
}()
return nil
}
func (es *subEventStore) Close() error {
return es.client.Close()
}
type redisEvent struct {
Data map[string]interface{}
}
func (re redisEvent) Encode() (map[string]interface{}, error) {
return re.Data, nil
}
func (es *subEventStore) handle(ctx context.Context, stream string, msgs []redis.XMessage, h events.EventHandler) {
for _, msg := range msgs {
var data map[string]interface{}
if err := json.Unmarshal([]byte(msg.Values["data"].(string)), &data); err != nil {
es.logger.Warn(fmt.Sprintf("failed to unmarshal redis event: %s", err))
return
}
event := redisEvent{
Data: data,
}
if err := h.Handle(ctx, event); err != nil {
es.logger.Warn(fmt.Sprintf("failed to handle redis event: %s", err))
return
}
if err := es.client.XAck(ctx, stream, group, msg.ID).Err(); err != nil {
es.logger.Warn(fmt.Sprintf("failed to ack redis event: %s", err))
return
}
}
}
-41
View File
@@ -1,41 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build nats
// +build nats
package store
import (
"context"
"log"
"log/slog"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/events/nats"
)
// StreamAllEvents represents subject to subscribe for all the events.
const StreamAllEvents = "events.>"
func init() {
log.Println("The binary was build using nats as the events store")
}
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
pb, err := nats.NewPublisher(ctx, url, stream)
if err != nil {
return nil, err
}
return pb, nil
}
func NewSubscriber(ctx context.Context, url string, logger *slog.Logger) (events.Subscriber, error) {
pb, err := nats.NewSubscriber(ctx, url, logger)
if err != nil {
return nil, err
}
return pb, nil
}
-41
View File
@@ -1,41 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build rabbitmq
// +build rabbitmq
package store
import (
"context"
"log"
"log/slog"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/events/rabbitmq"
)
// StreamAllEvents represents subject to subscribe for all the events.
const StreamAllEvents = "events.#"
func init() {
log.Println("The binary was build using rabbitmq as the events store")
}
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
pb, err := rabbitmq.NewPublisher(ctx, url, stream)
if err != nil {
return nil, err
}
return pb, nil
}
func NewSubscriber(_ context.Context, url string, logger *slog.Logger) (events.Subscriber, error) {
pb, err := rabbitmq.NewSubscriber(url, logger)
if err != nil {
return nil, err
}
return pb, nil
}
-41
View File
@@ -1,41 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build !nats && !rabbitmq
// +build !nats,!rabbitmq
package store
import (
"context"
"log"
"log/slog"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/events/redis"
)
// StreamAllEvents represents subject to subscribe for all the events.
const StreamAllEvents = ">"
func init() {
log.Println("The binary was build using redis as the events store")
}
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
pb, err := redis.NewPublisher(ctx, url, stream, events.UnpublishedEventsCheckInterval)
if err != nil {
return nil, err
}
return pb, nil
}
func NewSubscriber(_ context.Context, url string, logger *slog.Logger) (events.Subscriber, error) {
pb, err := redis.NewSubscriber(url, logger)
if err != nil {
return nil, err
}
return pb, nil
}
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package groups contains the domain concept definitions needed to support
// Magistrala groups functionality.
package groups
-17
View File
@@ -1,17 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package groups
import "errors"
var (
// ErrInvalidStatus indicates invalid status.
ErrInvalidStatus = errors.New("invalid groups status")
// ErrEnableGroup indicates error in enabling group.
ErrEnableGroup = errors.New("failed to enable group")
// ErrDisableGroup indicates error in disabling group.
ErrDisableGroup = errors.New("failed to disable group")
)
-133
View File
@@ -1,133 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package groups
import (
"context"
"time"
"github.com/absmach/magistrala/pkg/authn"
)
// MaxLevel represents the maximum group hierarchy level.
const MaxLevel = uint64(5)
// Group represents the group of Clients.
// Indicates a level in tree hierarchy. Root node is level 1.
// Path in a tree consisting of group IDs
// Paths are unique per domain.
type Group struct {
ID string `json:"id"`
Domain string `json:"domain_id,omitempty"`
Parent string `json:"parent_id,omitempty"`
Name string `json:"name"`
Description string `json:"description,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Level int `json:"level,omitempty"`
Path string `json:"path,omitempty"`
Children []*Group `json:"children,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
UpdatedBy string `json:"updated_by,omitempty"`
Status Status `json:"status"`
Permissions []string `json:"permissions,omitempty"`
}
type Member struct {
ID string `json:"id"`
Type string `json:"type"`
}
// Memberships contains page related metadata as well as list of memberships that
// belong to this page.
type MembersPage struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Members []Member `json:"members"`
}
// Page contains page related metadata as well as list
// of Groups that belong to the page.
type Page struct {
PageMeta
Path string
Level uint64
ParentID string
Permission string
ListPerms bool
Direction int64 // ancestors (+1) or descendants (-1)
Groups []Group
}
// Metadata represents arbitrary JSON.
type Metadata map[string]interface{}
// Repository specifies a group persistence API.
//
//go:generate mockery --name Repository --output=./mocks --filename repository.go --quiet --note "Copyright (c) Abstract Machines" --unroll-variadic=false
type Repository interface {
// Save group.
Save(ctx context.Context, g Group) (Group, error)
// Update a group.
Update(ctx context.Context, g Group) (Group, error)
// RetrieveByID retrieves group by its id.
RetrieveByID(ctx context.Context, id string) (Group, error)
// RetrieveAll retrieves all groups.
RetrieveAll(ctx context.Context, gm Page) (Page, error)
// RetrieveByIDs retrieves group by ids and query.
RetrieveByIDs(ctx context.Context, gm Page, ids ...string) (Page, error)
// ChangeStatus changes groups status to active or inactive
ChangeStatus(ctx context.Context, group Group) (Group, error)
// AssignParentGroup assigns parent group id to a given group id
AssignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) error
// UnassignParentGroup unassign parent group id fr given group id
UnassignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) error
// Delete a group
Delete(ctx context.Context, groupID string) error
}
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" --unroll-variadic=false
type Service interface {
// CreateGroup creates new group.
CreateGroup(ctx context.Context, session authn.Session, kind string, g Group) (Group, error)
// UpdateGroup updates the group identified by the provided ID.
UpdateGroup(ctx context.Context, session authn.Session, g Group) (Group, error)
// ViewGroup retrieves data about the group identified by ID.
ViewGroup(ctx context.Context, session authn.Session, id string) (Group, error)
// ViewGroupPerms retrieves permissions on the group id for the given authorized token.
ViewGroupPerms(ctx context.Context, session authn.Session, id string) ([]string, error)
// ListGroups retrieves a list of groups basesd on entity type and entity id.
ListGroups(ctx context.Context, session authn.Session, memberKind, memberID string, gm Page) (Page, error)
// ListMembers retrieves everything that is assigned to a group identified by groupID.
ListMembers(ctx context.Context, session authn.Session, groupID, permission, memberKind string) (MembersPage, error)
// EnableGroup logically enables the group identified with the provided ID.
EnableGroup(ctx context.Context, session authn.Session, id string) (Group, error)
// DisableGroup logically disables the group identified with the provided ID.
DisableGroup(ctx context.Context, session authn.Session, id string) (Group, error)
// DeleteGroup delete the given group id
DeleteGroup(ctx context.Context, session authn.Session, id string) error
// Assign member to group
Assign(ctx context.Context, session authn.Session, groupID, relation, memberKind string, memberIDs ...string) (err error)
// Unassign member from group
Unassign(ctx context.Context, session authn.Session, groupID, relation, memberKind string, memberIDs ...string) (err error)
}
-5
View File
@@ -1,5 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package mocks contains mocks for testing purposes.
package mocks
-253
View File
@@ -1,253 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
groups "github.com/absmach/magistrala/pkg/groups"
mock "github.com/stretchr/testify/mock"
)
// Repository is an autogenerated mock type for the Repository type
type Repository struct {
mock.Mock
}
// AssignParentGroup provides a mock function with given fields: ctx, parentGroupID, groupIDs
func (_m *Repository) AssignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) error {
ret := _m.Called(ctx, parentGroupID, groupIDs)
if len(ret) == 0 {
panic("no return value specified for AssignParentGroup")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, ...string) error); ok {
r0 = rf(ctx, parentGroupID, groupIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// ChangeStatus provides a mock function with given fields: ctx, group
func (_m *Repository) ChangeStatus(ctx context.Context, group groups.Group) (groups.Group, error) {
ret := _m.Called(ctx, group)
if len(ret) == 0 {
panic("no return value specified for ChangeStatus")
}
var r0 groups.Group
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) (groups.Group, error)); ok {
return rf(ctx, group)
}
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) groups.Group); ok {
r0 = rf(ctx, group)
} else {
r0 = ret.Get(0).(groups.Group)
}
if rf, ok := ret.Get(1).(func(context.Context, groups.Group) error); ok {
r1 = rf(ctx, group)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Delete provides a mock function with given fields: ctx, groupID
func (_m *Repository) Delete(ctx context.Context, groupID string) error {
ret := _m.Called(ctx, groupID)
if len(ret) == 0 {
panic("no return value specified for Delete")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, groupID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RetrieveAll provides a mock function with given fields: ctx, gm
func (_m *Repository) RetrieveAll(ctx context.Context, gm groups.Page) (groups.Page, error) {
ret := _m.Called(ctx, gm)
if len(ret) == 0 {
panic("no return value specified for RetrieveAll")
}
var r0 groups.Page
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, groups.Page) (groups.Page, error)); ok {
return rf(ctx, gm)
}
if rf, ok := ret.Get(0).(func(context.Context, groups.Page) groups.Page); ok {
r0 = rf(ctx, gm)
} else {
r0 = ret.Get(0).(groups.Page)
}
if rf, ok := ret.Get(1).(func(context.Context, groups.Page) error); ok {
r1 = rf(ctx, gm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveByID provides a mock function with given fields: ctx, id
func (_m *Repository) RetrieveByID(ctx context.Context, id string) (groups.Group, error) {
ret := _m.Called(ctx, id)
if len(ret) == 0 {
panic("no return value specified for RetrieveByID")
}
var r0 groups.Group
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (groups.Group, error)); ok {
return rf(ctx, id)
}
if rf, ok := ret.Get(0).(func(context.Context, string) groups.Group); ok {
r0 = rf(ctx, id)
} else {
r0 = ret.Get(0).(groups.Group)
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveByIDs provides a mock function with given fields: ctx, gm, ids
func (_m *Repository) RetrieveByIDs(ctx context.Context, gm groups.Page, ids ...string) (groups.Page, error) {
ret := _m.Called(ctx, gm, ids)
if len(ret) == 0 {
panic("no return value specified for RetrieveByIDs")
}
var r0 groups.Page
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, groups.Page, ...string) (groups.Page, error)); ok {
return rf(ctx, gm, ids...)
}
if rf, ok := ret.Get(0).(func(context.Context, groups.Page, ...string) groups.Page); ok {
r0 = rf(ctx, gm, ids...)
} else {
r0 = ret.Get(0).(groups.Page)
}
if rf, ok := ret.Get(1).(func(context.Context, groups.Page, ...string) error); ok {
r1 = rf(ctx, gm, ids...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Save provides a mock function with given fields: ctx, g
func (_m *Repository) Save(ctx context.Context, g groups.Group) (groups.Group, error) {
ret := _m.Called(ctx, g)
if len(ret) == 0 {
panic("no return value specified for Save")
}
var r0 groups.Group
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) (groups.Group, error)); ok {
return rf(ctx, g)
}
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) groups.Group); ok {
r0 = rf(ctx, g)
} else {
r0 = ret.Get(0).(groups.Group)
}
if rf, ok := ret.Get(1).(func(context.Context, groups.Group) error); ok {
r1 = rf(ctx, g)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UnassignParentGroup provides a mock function with given fields: ctx, parentGroupID, groupIDs
func (_m *Repository) UnassignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) error {
ret := _m.Called(ctx, parentGroupID, groupIDs)
if len(ret) == 0 {
panic("no return value specified for UnassignParentGroup")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, ...string) error); ok {
r0 = rf(ctx, parentGroupID, groupIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// Update provides a mock function with given fields: ctx, g
func (_m *Repository) Update(ctx context.Context, g groups.Group) (groups.Group, error) {
ret := _m.Called(ctx, g)
if len(ret) == 0 {
panic("no return value specified for Update")
}
var r0 groups.Group
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) (groups.Group, error)); ok {
return rf(ctx, g)
}
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) groups.Group); ok {
r0 = rf(ctx, g)
} else {
r0 = ret.Get(0).(groups.Group)
}
if rf, ok := ret.Get(1).(func(context.Context, groups.Group) error); ok {
r1 = rf(ctx, g)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewRepository(t interface {
mock.TestingT
Cleanup(func())
}) *Repository {
mock := &Repository{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-314
View File
@@ -1,314 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
authn "github.com/absmach/magistrala/pkg/authn"
groups "github.com/absmach/magistrala/pkg/groups"
mock "github.com/stretchr/testify/mock"
)
// Service is an autogenerated mock type for the Service type
type Service struct {
mock.Mock
}
// Assign provides a mock function with given fields: ctx, session, groupID, relation, memberKind, memberIDs
func (_m *Service) Assign(ctx context.Context, session authn.Session, groupID string, relation string, memberKind string, memberIDs ...string) error {
ret := _m.Called(ctx, session, groupID, relation, memberKind, memberIDs)
if len(ret) == 0 {
panic("no return value specified for Assign")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string, ...string) error); ok {
r0 = rf(ctx, session, groupID, relation, memberKind, memberIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// CreateGroup provides a mock function with given fields: ctx, session, kind, g
func (_m *Service) CreateGroup(ctx context.Context, session authn.Session, kind string, g groups.Group) (groups.Group, error) {
ret := _m.Called(ctx, session, kind, g)
if len(ret) == 0 {
panic("no return value specified for CreateGroup")
}
var r0 groups.Group
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, groups.Group) (groups.Group, error)); ok {
return rf(ctx, session, kind, g)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, groups.Group) groups.Group); ok {
r0 = rf(ctx, session, kind, g)
} else {
r0 = ret.Get(0).(groups.Group)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, groups.Group) error); ok {
r1 = rf(ctx, session, kind, g)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// DeleteGroup provides a mock function with given fields: ctx, session, id
func (_m *Service) DeleteGroup(ctx context.Context, session authn.Session, id string) error {
ret := _m.Called(ctx, session, id)
if len(ret) == 0 {
panic("no return value specified for DeleteGroup")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok {
r0 = rf(ctx, session, id)
} else {
r0 = ret.Error(0)
}
return r0
}
// DisableGroup provides a mock function with given fields: ctx, session, id
func (_m *Service) DisableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
ret := _m.Called(ctx, session, id)
if len(ret) == 0 {
panic("no return value specified for DisableGroup")
}
var r0 groups.Group
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) (groups.Group, error)); ok {
return rf(ctx, session, id)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) groups.Group); ok {
r0 = rf(ctx, session, id)
} else {
r0 = ret.Get(0).(groups.Group)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok {
r1 = rf(ctx, session, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// EnableGroup provides a mock function with given fields: ctx, session, id
func (_m *Service) EnableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
ret := _m.Called(ctx, session, id)
if len(ret) == 0 {
panic("no return value specified for EnableGroup")
}
var r0 groups.Group
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) (groups.Group, error)); ok {
return rf(ctx, session, id)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) groups.Group); ok {
r0 = rf(ctx, session, id)
} else {
r0 = ret.Get(0).(groups.Group)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok {
r1 = rf(ctx, session, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListGroups provides a mock function with given fields: ctx, session, memberKind, memberID, gm
func (_m *Service) ListGroups(ctx context.Context, session authn.Session, memberKind string, memberID string, gm groups.Page) (groups.Page, error) {
ret := _m.Called(ctx, session, memberKind, memberID, gm)
if len(ret) == 0 {
panic("no return value specified for ListGroups")
}
var r0 groups.Page
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, groups.Page) (groups.Page, error)); ok {
return rf(ctx, session, memberKind, memberID, gm)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, groups.Page) groups.Page); ok {
r0 = rf(ctx, session, memberKind, memberID, gm)
} else {
r0 = ret.Get(0).(groups.Page)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, groups.Page) error); ok {
r1 = rf(ctx, session, memberKind, memberID, gm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListMembers provides a mock function with given fields: ctx, session, groupID, permission, memberKind
func (_m *Service) ListMembers(ctx context.Context, session authn.Session, groupID string, permission string, memberKind string) (groups.MembersPage, error) {
ret := _m.Called(ctx, session, groupID, permission, memberKind)
if len(ret) == 0 {
panic("no return value specified for ListMembers")
}
var r0 groups.MembersPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string) (groups.MembersPage, error)); ok {
return rf(ctx, session, groupID, permission, memberKind)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string) groups.MembersPage); ok {
r0 = rf(ctx, session, groupID, permission, memberKind)
} else {
r0 = ret.Get(0).(groups.MembersPage)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, string) error); ok {
r1 = rf(ctx, session, groupID, permission, memberKind)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Unassign provides a mock function with given fields: ctx, session, groupID, relation, memberKind, memberIDs
func (_m *Service) Unassign(ctx context.Context, session authn.Session, groupID string, relation string, memberKind string, memberIDs ...string) error {
ret := _m.Called(ctx, session, groupID, relation, memberKind, memberIDs)
if len(ret) == 0 {
panic("no return value specified for Unassign")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string, ...string) error); ok {
r0 = rf(ctx, session, groupID, relation, memberKind, memberIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateGroup provides a mock function with given fields: ctx, session, g
func (_m *Service) UpdateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) {
ret := _m.Called(ctx, session, g)
if len(ret) == 0 {
panic("no return value specified for UpdateGroup")
}
var r0 groups.Group
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, groups.Group) (groups.Group, error)); ok {
return rf(ctx, session, g)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, groups.Group) groups.Group); ok {
r0 = rf(ctx, session, g)
} else {
r0 = ret.Get(0).(groups.Group)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, groups.Group) error); ok {
r1 = rf(ctx, session, g)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ViewGroup provides a mock function with given fields: ctx, session, id
func (_m *Service) ViewGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
ret := _m.Called(ctx, session, id)
if len(ret) == 0 {
panic("no return value specified for ViewGroup")
}
var r0 groups.Group
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) (groups.Group, error)); ok {
return rf(ctx, session, id)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) groups.Group); ok {
r0 = rf(ctx, session, id)
} else {
r0 = ret.Get(0).(groups.Group)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok {
r1 = rf(ctx, session, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ViewGroupPerms provides a mock function with given fields: ctx, session, id
func (_m *Service) ViewGroupPerms(ctx context.Context, session authn.Session, id string) ([]string, error) {
ret := _m.Called(ctx, session, id)
if len(ret) == 0 {
panic("no return value specified for ViewGroupPerms")
}
var r0 []string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) ([]string, error)); ok {
return rf(ctx, session, id)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) []string); ok {
r0 = rf(ctx, session, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok {
r1 = rf(ctx, session, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
mock.TestingT
Cleanup(func())
}) *Service {
mock := &Service{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-17
View File
@@ -1,17 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package groups
// PageMeta contains page metadata that helps navigation.
type PageMeta struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Name string `json:"name,omitempty"`
ID string `json:"id,omitempty"`
DomainID string `json:"domain_id,omitempty"`
Tag string `json:"tag,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Status Status `json:"status,omitempty"`
}
-83
View File
@@ -1,83 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package groups
import (
"encoding/json"
"strings"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
)
// Status represents User status.
type Status uint8
// Possible User status values.
const (
// EnabledStatus represents enabled User.
EnabledStatus Status = iota
// DisabledStatus represents disabled User.
DisabledStatus
// DeletedStatus represents a user that will be deleted.
DeletedStatus
// AllStatus is used for querying purposes to list users irrespective
// of their status - both enabled and disabled. It is never stored in the
// database as the actual User status and should always be the largest
// value in this enumeration.
AllStatus
)
// String representation of the possible status values.
const (
Disabled = "disabled"
Enabled = "enabled"
Deleted = "deleted"
All = "all"
Unknown = "unknown"
)
// String converts user/group status to string literal.
func (s Status) String() string {
switch s {
case DisabledStatus:
return Disabled
case EnabledStatus:
return Enabled
case DeletedStatus:
return Deleted
case AllStatus:
return All
default:
return Unknown
}
}
// ToStatus converts string value to a valid User/Group status.
func ToStatus(status string) (Status, error) {
switch status {
case "", Enabled:
return EnabledStatus, nil
case Disabled:
return DisabledStatus, nil
case Deleted:
return DeletedStatus, nil
case All:
return AllStatus, nil
}
return Status(0), svcerr.ErrInvalidStatus
}
// Custom Marshaller for Uesr/Groups.
func (s Status) MarshalJSON() ([]byte, error) {
return json.Marshal(s.String())
}
// Custom Unmarshaler for User/Groups.
func (s *Status) UnmarshalJSON(data []byte) error {
str := strings.Trim(string(data), "\"")
val, err := ToStatus(str)
*s = val
return err
}
-80
View File
@@ -1,80 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpcclient
import (
"context"
"github.com/absmach/magistrala"
domainsgrpc "github.com/absmach/magistrala/auth/api/grpc/domains"
tokengrpc "github.com/absmach/magistrala/auth/api/grpc/token"
thingsauth "github.com/absmach/magistrala/things/api/grpc"
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
)
// SetupTokenClient loads auth services token gRPC configuration and creates new Token services gRPC client.
//
// For example:
//
// tokenClient, tokenHandler, err := grpcclient.SetupTokenClient(ctx, grpcclient.Config{}).
func SetupTokenClient(ctx context.Context, cfg Config) (magistrala.TokenServiceClient, Handler, error) {
client, err := NewHandler(cfg)
if err != nil {
return nil, nil, err
}
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "auth",
})
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
return nil, nil, ErrSvcNotServing
}
return tokengrpc.NewTokenClient(client.Connection(), cfg.Timeout), client, nil
}
// SetupDomiansClient loads domains gRPC configuration and creates a new domains gRPC client.
//
// For example:
//
// domainsClient, domainsHandler, err := grpcclient.SetupDomainsClient(ctx, grpcclient.Config{}).
func SetupDomainsClient(ctx context.Context, cfg Config) (magistrala.DomainsServiceClient, Handler, error) {
client, err := NewHandler(cfg)
if err != nil {
return nil, nil, err
}
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "auth",
})
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
return nil, nil, ErrSvcNotServing
}
return domainsgrpc.NewDomainsClient(client.Connection(), cfg.Timeout), client, nil
}
// SetupThingsClient loads things gRPC configuration and creates new things gRPC client.
//
// For example:
//
// thingClient, thingHandler, err := grpcclient.SetupThings(ctx, grpcclient.Config{}).
func SetupThingsClient(ctx context.Context, cfg Config) (magistrala.ThingsServiceClient, Handler, error) {
client, err := NewHandler(cfg)
if err != nil {
return nil, nil, err
}
health := grpchealth.NewHealthClient(client.Connection())
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
Service: "things",
})
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
return nil, nil, ErrSvcNotServing
}
return thingsauth.NewClient(client.Connection(), cfg.Timeout), client, nil
}
-179
View File
@@ -1,179 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpcclient_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/absmach/magistrala"
domainsgrpcapi "github.com/absmach/magistrala/auth/api/grpc/domains"
tokengrpcapi "github.com/absmach/magistrala/auth/api/grpc/token"
"github.com/absmach/magistrala/auth/mocks"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/errors"
"github.com/absmach/magistrala/pkg/grpcclient"
"github.com/absmach/magistrala/pkg/server"
grpcserver "github.com/absmach/magistrala/pkg/server/grpc"
thingsgrpcapi "github.com/absmach/magistrala/things/api/grpc"
thmocks "github.com/absmach/magistrala/things/mocks"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
func TestSetupToken(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
registerAuthServiceServer := func(srv *grpc.Server) {
magistrala.RegisterTokenServiceServer(srv, tokengrpcapi.NewTokenServer(new(mocks.Service)))
}
gs := grpcserver.NewServer(ctx, cancel, "auth", server.Config{Port: "12345"}, registerAuthServiceServer, mglog.NewMock())
go func() {
err := gs.Start()
assert.Nil(t, err, fmt.Sprintf(`"Unexpected error creating server %s"`, err))
}()
defer func() {
err := gs.Stop()
assert.Nil(t, err, fmt.Sprintf(`"Unexpected error stopping server %s"`, err))
}()
cases := []struct {
desc string
config grpcclient.Config
err error
}{
{
desc: "successful",
config: grpcclient.Config{
URL: "localhost:12345",
Timeout: time.Second,
},
err: nil,
},
{
desc: "failed with empty URL",
config: grpcclient.Config{
URL: "",
Timeout: time.Second,
},
err: errors.New("service is not serving"),
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
client, handler, err := grpcclient.SetupTokenClient(context.Background(), c.config)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s", err, c.err))
if err == nil {
assert.NotNil(t, client)
assert.NotNil(t, handler)
}
})
}
}
func TestSetupThingsClient(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
registerThingsServiceServer := func(srv *grpc.Server) {
magistrala.RegisterThingsServiceServer(srv, thingsgrpcapi.NewServer(new(thmocks.Service)))
}
gs := grpcserver.NewServer(ctx, cancel, "things", server.Config{Port: "12345"}, registerThingsServiceServer, mglog.NewMock())
go func() {
err := gs.Start()
assert.Nil(t, err, fmt.Sprintf(`"Unexpected error creating server %s"`, err))
}()
defer func() {
err := gs.Stop()
assert.Nil(t, err, fmt.Sprintf(`"Unexpected error stopping server %s"`, err))
}()
cases := []struct {
desc string
config grpcclient.Config
err error
}{
{
desc: "successful",
config: grpcclient.Config{
URL: "localhost:12345",
Timeout: time.Second,
},
err: nil,
},
{
desc: "failed with empty URL",
config: grpcclient.Config{
URL: "",
Timeout: time.Second,
},
err: errors.New("service is not serving"),
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
client, handler, err := grpcclient.SetupThingsClient(context.Background(), c.config)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s", err, c.err))
if err == nil {
assert.NotNil(t, client)
assert.NotNil(t, handler)
}
})
}
}
func TestSetupDomainsClient(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
registerDomainsServiceServer := func(srv *grpc.Server) {
magistrala.RegisterDomainsServiceServer(srv, domainsgrpcapi.NewDomainsServer(new(mocks.Service)))
}
gs := grpcserver.NewServer(ctx, cancel, "auth", server.Config{Port: "12345"}, registerDomainsServiceServer, mglog.NewMock())
go func() {
err := gs.Start()
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating server %s", err))
}()
defer func() {
err := gs.Stop()
assert.Nil(t, err, fmt.Sprintf("Unexpected error stopping server %s", err))
}()
cases := []struct {
desc string
config grpcclient.Config
err error
}{
{
desc: "successfully",
config: grpcclient.Config{
URL: "localhost:12345",
Timeout: time.Second,
},
err: nil,
},
{
desc: "failed with empty URL",
config: grpcclient.Config{
URL: "",
Timeout: time.Second,
},
err: errors.New("service is not serving"),
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
client, handler, err := grpcclient.SetupDomainsClient(context.Background(), c.config)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s", err, c.err))
if err == nil {
assert.NotNil(t, client)
assert.NotNil(t, handler)
}
})
}
}
-153
View File
@@ -1,153 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpcclient
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"time"
"github.com/absmach/magistrala/pkg/errors"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
type security int
const (
withoutTLS security = iota
withTLS
withmTLS
)
const buffSize = 10 * 1024 * 1024
var (
errGrpcConnect = errors.New("failed to connect to grpc server")
errGrpcClose = errors.New("failed to close grpc connection")
ErrSvcNotServing = errors.New("service is not serving")
)
type Config struct {
URL string `env:"URL" envDefault:""`
Timeout time.Duration `env:"TIMEOUT" envDefault:"1s"`
ClientCert string `env:"CLIENT_CERT" envDefault:""`
ClientKey string `env:"CLIENT_KEY" envDefault:""`
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
}
// Handler is used to handle gRPC connection.
type Handler interface {
// Close closes gRPC connection.
Close() error
// Secure is used for pretty printing TLS info.
Secure() string
// Connection returns the gRPC connection.
Connection() *grpc.ClientConn
}
type client struct {
*grpc.ClientConn
cfg Config
secure security
}
var _ Handler = (*client)(nil)
func NewHandler(cfg Config) (Handler, error) {
conn, secure, err := connect(cfg)
if err != nil {
return nil, err
}
return &client{
ClientConn: conn,
cfg: cfg,
secure: secure,
}, nil
}
func (c *client) Close() error {
if err := c.ClientConn.Close(); err != nil {
return errors.Wrap(errGrpcClose, err)
}
return nil
}
func (c *client) Connection() *grpc.ClientConn {
return c.ClientConn
}
// Secure is used for pretty printing TLS info.
func (c *client) Secure() string {
switch c.secure {
case withTLS:
return "with TLS"
case withmTLS:
return "with mTLS"
case withoutTLS:
fallthrough
default:
return "without TLS"
}
}
// connect creates new gRPC client and connect to gRPC server.
func connect(cfg Config) (*grpc.ClientConn, security, error) {
opts := []grpc.DialOption{
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
}
secure := withoutTLS
tc := insecure.NewCredentials()
if cfg.ServerCAFile != "" {
tlsConfig := &tls.Config{}
// Loading root ca certificates file
rootCA, err := os.ReadFile(cfg.ServerCAFile)
if err != nil {
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
}
if len(rootCA) > 0 {
capool := x509.NewCertPool()
if !capool.AppendCertsFromPEM(rootCA) {
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
}
tlsConfig.RootCAs = capool
secure = withTLS
}
// Loading mtls certificates file
if cfg.ClientCert != "" || cfg.ClientKey != "" {
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
if err != nil {
return nil, secure, fmt.Errorf("failed to client certificate and key %w", err)
}
tlsConfig.Certificates = []tls.Certificate{certificate}
secure = withmTLS
}
tc = credentials.NewTLS(tlsConfig)
}
opts = append(
opts, grpc.WithTransportCredentials(tc),
grpc.WithReadBufferSize(buffSize),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(buffSize/10), grpc.MaxCallSendMsgSize(buffSize/10)),
grpc.WithWriteBufferSize(buffSize),
)
conn, err := grpc.NewClient(cfg.URL, opts...)
if err != nil {
return nil, secure, errors.Wrap(errGrpcConnect, err)
}
return conn, secure, nil
}
-114
View File
@@ -1,114 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpcclient
import (
"fmt"
"testing"
"time"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestHandler(t *testing.T) {
cases := []struct {
desc string
config Config
err error
secure string
}{
{
desc: "successful without TLS",
config: Config{
URL: "localhost:8080",
Timeout: time.Second,
},
err: nil,
secure: "without TLS",
},
{
desc: "successful with TLS",
config: Config{
URL: "localhost:8080",
Timeout: time.Second,
ServerCAFile: "../../docker/ssl/certs/ca.crt",
},
err: nil,
secure: "with TLS",
},
{
desc: "successful with mTLS",
config: Config{
URL: "localhost:8080",
Timeout: time.Second,
ClientCert: "../../docker/ssl/certs/magistrala-server.crt",
ClientKey: "../../docker/ssl/certs/magistrala-server.key",
ServerCAFile: "../../docker/ssl/certs/ca.crt",
},
err: nil,
secure: "with mTLS",
},
{
desc: "failed with empty URL",
config: Config{
URL: "",
Timeout: time.Second,
},
secure: "without TLS",
},
{
desc: "failed with invalid server CA file",
config: Config{
URL: "localhost:8080",
Timeout: time.Second,
ServerCAFile: "invalid",
},
err: errors.New("failed to load root ca file: open invalid: no such file or directory"),
},
{
desc: "failed with invalid server CA file as cert key",
config: Config{
URL: "localhost:8080",
Timeout: time.Second,
ServerCAFile: "../../docker/ssl/certs/magistrala-server.key",
},
err: errors.New("failed to append root ca to tls.Config"),
},
{
desc: "failed with invalid client cert",
config: Config{
URL: "localhost:8080",
Timeout: time.Second,
ClientCert: "invalid",
ClientKey: "../../docker/ssl/certs/magistrala-server.key",
ServerCAFile: "../../docker/ssl/certs/ca.crt",
},
err: errors.New("failed to client certificate and key open invalid: no such file or directory"),
},
{
desc: "failed with invalid client key",
config: Config{
URL: "localhost:8080",
Timeout: time.Second,
ClientCert: "../../docker/ssl/certs/magistrala-server.crt",
ClientKey: "invalid",
ServerCAFile: "../../docker/ssl/certs/ca.crt",
},
err: errors.New("failed to client certificate and key open invalid: no such file or directory"),
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
handler, err := NewHandler(c.config)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s", err, c.err))
if err == nil {
assert.Equal(t, c.secure, handler.Secure())
assert.NotNil(t, handler.Connection())
assert.Nil(t, handler.Close())
}
})
}
}
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package auth contains the domain concept definitions needed to support
// Magistrala auth functionality.
package grpcclient
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package jaeger contains the domain concept definitions needed to support
// Magistrala Jaeger tracing functionality.
package jaeger
-77
View File
@@ -1,77 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package jaeger
import (
"context"
"errors"
"net/url"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
"go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
)
var (
errNoURL = errors.New("URL is empty")
errNoSvcName = errors.New("service Name is empty")
errUnsupportedTraceURLScheme = errors.New("unsupported tracing url scheme")
)
// NewProvider initializes Jaeger TraceProvider.
//
// tp, err := jaeger.NewProvider(ctx, "demo-service", "http://localhost:14268/api/traces", "2cb32911-6833-469c-9cad-4d3e93c528d8", "1.0")
func NewProvider(ctx context.Context, svcName string, jaegerUrl url.URL, instanceID string, fraction float64) (*trace.TracerProvider, error) {
if jaegerUrl == (url.URL{}) {
return nil, errNoURL
}
if svcName == "" {
return nil, errNoSvcName
}
var client otlptrace.Client
switch jaegerUrl.Scheme {
case "http":
client = otlptracehttp.NewClient(otlptracehttp.WithEndpoint(jaegerUrl.Host), otlptracehttp.WithURLPath(jaegerUrl.Path), otlptracehttp.WithInsecure())
case "https":
client = otlptracehttp.NewClient(otlptracehttp.WithEndpoint(jaegerUrl.Host), otlptracehttp.WithURLPath(jaegerUrl.Path))
default:
return nil, errUnsupportedTraceURLScheme
}
exporter, err := otlptrace.New(ctx, client)
if err != nil {
return nil, err
}
attributes := []attribute.KeyValue{
semconv.ServiceNameKey.String(svcName),
attribute.String("host.id", instanceID),
}
hostAttr, err := resource.New(ctx, resource.WithHost(), resource.WithOSDescription(), resource.WithContainer())
if err != nil {
return nil, err
}
attributes = append(attributes, hostAttr.Attributes()...)
tp := trace.NewTracerProvider(
trace.WithSampler(trace.TraceIDRatioBased(fraction)),
trace.WithBatcher(exporter),
trace.WithResource(resource.NewWithAttributes(
semconv.SchemaURL,
attributes...,
)),
)
otel.SetTracerProvider(tp)
otel.SetTextMapPropagator(propagation.TraceContext{})
return tp, nil
}
-9
View File
@@ -1,9 +0,0 @@
# Messaging
`messaging` package defines `Publisher`, `Subscriber` and an aggregate `Pubsub` interface.
`Subscriber` interface defines methods used to subscribe to a message broker such as MQTT or NATS or RabbitMQ.
`Publisher` interface defines methods used to publish messages to a message broker such as MQTT or NATS or RabbitMQ.
`Pubsub` interface is composed of `Publisher` and `Subscriber` interface and can be used to send messages to as well as to receive messages from a message broker.
-41
View File
@@ -1,41 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build !rabbitmq
// +build !rabbitmq
package brokers
import (
"context"
"log"
"log/slog"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/nats"
)
// SubjectAllChannels represents subject to subscribe for all the channels.
const SubjectAllChannels = "channels.>"
func init() {
log.Println("The binary was build using Nats as the message broker")
}
func NewPublisher(ctx context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) {
pb, err := nats.NewPublisher(ctx, url, opts...)
if err != nil {
return nil, err
}
return pb, nil
}
func NewPubSub(ctx context.Context, url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) {
pb, err := nats.NewPubSub(ctx, url, logger, opts...)
if err != nil {
return nil, err
}
return pb, nil
}
-41
View File
@@ -1,41 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build rabbitmq
// +build rabbitmq
package brokers
import (
"context"
"log"
"log/slog"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/rabbitmq"
)
// SubjectAllChannels represents subject to subscribe for all the channels.
const SubjectAllChannels = "channels.#"
func init() {
log.Println("The binary was build using RabbitMQ as the message broker")
}
func NewPublisher(_ context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) {
pb, err := rabbitmq.NewPublisher(url, opts...)
if err != nil {
return nil, err
}
return pb, nil
}
func NewPubSub(_ context.Context, url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) {
pb, err := rabbitmq.NewPubSub(url, logger, opts...)
if err != nil {
return nil, err
}
return pb, nil
}
@@ -1,31 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build !rabbitmq
// +build !rabbitmq
package brokers
import (
"log"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/nats/tracing"
"github.com/absmach/magistrala/pkg/server"
"go.opentelemetry.io/otel/trace"
)
// SubjectAllChannels represents subject to subscribe for all the channels.
const SubjectAllChannels = "channels.>"
func init() {
log.Println("The binary was build using Nats as the message broker")
}
func NewPublisher(cfg server.Config, tracer trace.Tracer, publisher messaging.Publisher) messaging.Publisher {
return tracing.NewPublisher(cfg, tracer, publisher)
}
func NewPubSub(cfg server.Config, tracer trace.Tracer, pubsub messaging.PubSub) messaging.PubSub {
return tracing.NewPubSub(cfg, tracer, pubsub)
}
@@ -1,31 +0,0 @@
//go:build rabbitmq
// +build rabbitmq
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package brokers
import (
"log"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/rabbitmq/tracing"
"github.com/absmach/magistrala/pkg/server"
"go.opentelemetry.io/otel/trace"
)
// SubjectAllChannels represents subject to subscribe for all the channels.
const SubjectAllChannels = "channels.#"
func init() {
log.Println("The binary was build using RabbitMQ as the message broker")
}
func NewPublisher(cfg server.Config, tracer trace.Tracer, pub messaging.Publisher) messaging.Publisher {
return tracing.NewPublisher(cfg, tracer, pub)
}
func NewPubSub(cfg server.Config, tracer trace.Tracer, pubsub messaging.PubSub) messaging.PubSub {
return tracing.NewPubSub(cfg, tracer, pubsub)
}
-90
View File
@@ -1,90 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build !test
package handler
import (
"context"
"log/slog"
"time"
"github.com/absmach/mgate/pkg/session"
)
var _ session.Handler = (*loggingMiddleware)(nil)
type loggingMiddleware struct {
logger *slog.Logger
svc session.Handler
}
// AuthConnect implements session.Handler.
func (lm *loggingMiddleware) AuthConnect(ctx context.Context) (err error) {
defer lm.logAction("AuthConnect", nil, time.Now(), err)
return lm.svc.AuthConnect(ctx)
}
// AuthPublish implements session.Handler.
func (lm *loggingMiddleware) AuthPublish(ctx context.Context, topic *string, payload *[]byte) (err error) {
defer lm.logAction("AuthPublish", &[]string{*topic}, time.Now(), err)
return lm.svc.AuthPublish(ctx, topic, payload)
}
// AuthSubscribe implements session.Handler.
func (lm *loggingMiddleware) AuthSubscribe(ctx context.Context, topics *[]string) (err error) {
defer lm.logAction("AuthSubscribe", topics, time.Now(), err)
return lm.svc.AuthSubscribe(ctx, topics)
}
// Connect implements session.Handler.
func (lm *loggingMiddleware) Connect(ctx context.Context) (err error) {
defer lm.logAction("Connect", nil, time.Now(), err)
return lm.svc.Connect(ctx)
}
// Disconnect implements session.Handler.
func (lm *loggingMiddleware) Disconnect(ctx context.Context) (err error) {
defer lm.logAction("Disconnect", nil, time.Now(), err)
return lm.svc.Disconnect(ctx)
}
// Publish logs the publish request. It logs the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Publish(ctx context.Context, topic *string, payload *[]byte) (err error) {
defer lm.logAction("Publish", &[]string{*topic}, time.Now(), err)
return lm.svc.Publish(ctx, topic, payload)
}
// Subscribe implements session.Handler.
func (lm *loggingMiddleware) Subscribe(ctx context.Context, topics *[]string) (err error) {
defer lm.logAction("Subscribe", topics, time.Now(), err)
return lm.svc.Subscribe(ctx, topics)
}
// Unsubscribe implements session.Handler.
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, topics *[]string) (err error) {
defer lm.logAction("Unsubscribe", topics, time.Now(), err)
return lm.svc.Unsubscribe(ctx, topics)
}
// LoggingMiddleware adds logging facilities to the adapter.
func LoggingMiddleware(svc session.Handler, logger *slog.Logger) session.Handler {
return &loggingMiddleware{logger, svc}
}
func (lm *loggingMiddleware) logAction(action string, topics *[]string, t time.Time, err error) {
args := []any{
slog.String("duration", time.Since(t).String()),
}
if topics != nil {
args = append(args, slog.Any("topics", *topics))
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn(action+" failed", args...)
return
}
lm.logger.Info(action+" completed successfully", args...)
}
-86
View File
@@ -1,86 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build !test
package handler
import (
"context"
"time"
"github.com/absmach/mgate/pkg/session"
"github.com/go-kit/kit/metrics"
)
var _ session.Handler = (*metricsMiddleware)(nil)
type metricsMiddleware struct {
counter metrics.Counter
latency metrics.Histogram
svc session.Handler
}
// MetricsMiddleware instruments adapter by tracking request count and latency.
func MetricsMiddleware(svc session.Handler, counter metrics.Counter, latency metrics.Histogram) session.Handler {
return &metricsMiddleware{
counter: counter,
latency: latency,
svc: svc,
}
}
// AuthConnect implements session.Handler.
func (mm *metricsMiddleware) AuthConnect(ctx context.Context) error {
defer func(begin time.Time) {
mm.counter.With("method", "publish").Add(1)
mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.AuthConnect(ctx)
}
// AuthPublish implements session.Handler.
func (mm *metricsMiddleware) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error {
defer func(begin time.Time) {
mm.counter.With("method", "publish").Add(1)
mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.AuthPublish(ctx, topic, payload)
}
// AuthSubscribe implements session.Handler.
func (*metricsMiddleware) AuthSubscribe(ctx context.Context, topics *[]string) error {
return nil
}
// Connect implements session.Handler.
func (*metricsMiddleware) Connect(ctx context.Context) error {
return nil
}
// Disconnect implements session.Handler.
func (*metricsMiddleware) Disconnect(ctx context.Context) error {
return nil
}
// Publish instruments Publish method with metrics.
func (mm *metricsMiddleware) Publish(ctx context.Context, topic *string, payload *[]byte) error {
defer func(begin time.Time) {
mm.counter.With("method", "publish").Add(1)
mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Publish(ctx, topic, payload)
}
// Subscribe implements session.Handler.
func (*metricsMiddleware) Subscribe(ctx context.Context, topics *[]string) error {
return nil
}
// Unsubscribe implements session.Handler.
func (*metricsMiddleware) Unsubscribe(ctx context.Context, topics *[]string) error {
return nil
}
-116
View File
@@ -1,116 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package handler
import (
"context"
"github.com/absmach/mgate/pkg/session"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
const (
authConnectOP = "auth_connect_op"
authPublishOP = "auth_publish_op"
authSubscribeOP = "auth_subscribe_op"
connectOP = "connect_op"
disconnectOP = "disconnect_op"
subscribeOP = "subscribe_op"
unsubscribeOP = "unsubscribe_op"
publishOP = "publish_op"
)
var _ session.Handler = (*handlerMiddleware)(nil)
type handlerMiddleware struct {
handler session.Handler
tracer trace.Tracer
}
// NewHandler creates a new session.Handler middleware with tracing.
func NewTracing(tracer trace.Tracer, handler session.Handler) session.Handler {
return &handlerMiddleware{
tracer: tracer,
handler: handler,
}
}
// AuthConnect traces auth connect operations.
func (h *handlerMiddleware) AuthConnect(ctx context.Context) error {
kvOpts := []attribute.KeyValue{}
s, ok := session.FromContext(ctx)
if ok {
kvOpts = append(kvOpts, attribute.String("client_id", s.ID))
kvOpts = append(kvOpts, attribute.String("username", s.Username))
}
ctx, span := h.tracer.Start(ctx, authConnectOP, trace.WithAttributes(kvOpts...))
defer span.End()
return h.handler.AuthConnect(ctx)
}
// AuthPublish traces auth publish operations.
func (h *handlerMiddleware) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error {
kvOpts := []attribute.KeyValue{}
s, ok := session.FromContext(ctx)
if ok {
kvOpts = append(kvOpts, attribute.String("client_id", s.ID))
if topic != nil {
kvOpts = append(kvOpts, attribute.String("topic", *topic))
}
}
ctx, span := h.tracer.Start(ctx, authPublishOP, trace.WithAttributes(kvOpts...))
defer span.End()
return h.handler.AuthPublish(ctx, topic, payload)
}
// AuthSubscribe traces auth subscribe operations.
func (h *handlerMiddleware) AuthSubscribe(ctx context.Context, topics *[]string) error {
kvOpts := []attribute.KeyValue{}
s, ok := session.FromContext(ctx)
if ok {
kvOpts = append(kvOpts, attribute.String("client_id", s.ID))
if topics != nil {
kvOpts = append(kvOpts, attribute.StringSlice("topics", *topics))
}
}
ctx, span := h.tracer.Start(ctx, authSubscribeOP, trace.WithAttributes(kvOpts...))
defer span.End()
return h.handler.AuthSubscribe(ctx, topics)
}
// Connect traces connect operations.
func (h *handlerMiddleware) Connect(ctx context.Context) error {
ctx, span := h.tracer.Start(ctx, connectOP)
defer span.End()
return h.handler.Connect(ctx)
}
// Disconnect traces disconnect operations.
func (h *handlerMiddleware) Disconnect(ctx context.Context) error {
ctx, span := h.tracer.Start(ctx, disconnectOP)
defer span.End()
return h.handler.Disconnect(ctx)
}
// Publish traces publish operations.
func (h *handlerMiddleware) Publish(ctx context.Context, topic *string, payload *[]byte) error {
ctx, span := h.tracer.Start(ctx, publishOP)
defer span.End()
return h.handler.Publish(ctx, topic, payload)
}
// Subscribe traces subscribe operations.
func (h *handlerMiddleware) Subscribe(ctx context.Context, topics *[]string) error {
ctx, span := h.tracer.Start(ctx, subscribeOP)
defer span.End()
return h.handler.Subscribe(ctx, topics)
}
// Unsubscribe traces unsubscribe operations.
func (h *handlerMiddleware) Unsubscribe(ctx context.Context, topics *[]string) error {
ctx, span := h.tracer.Start(ctx, unsubscribeOP)
defer span.End()
return h.handler.Unsubscribe(ctx, topics)
}
-195
View File
@@ -1,195 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.2
// protoc v5.27.1
// source: pkg/messaging/message.proto
package messaging
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// Message represents a message emitted by the Magistrala adapters layer.
type Message struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Channel string `protobuf:"bytes,1,opt,name=channel,proto3" json:"channel,omitempty"`
Subtopic string `protobuf:"bytes,2,opt,name=subtopic,proto3" json:"subtopic,omitempty"`
Publisher string `protobuf:"bytes,3,opt,name=publisher,proto3" json:"publisher,omitempty"`
Protocol string `protobuf:"bytes,4,opt,name=protocol,proto3" json:"protocol,omitempty"`
Payload []byte `protobuf:"bytes,5,opt,name=payload,proto3" json:"payload,omitempty"`
Created int64 `protobuf:"varint,6,opt,name=created,proto3" json:"created,omitempty"` // Unix timestamp in nanoseconds
}
func (x *Message) Reset() {
*x = Message{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_messaging_message_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Message) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Message) ProtoMessage() {}
func (x *Message) ProtoReflect() protoreflect.Message {
mi := &file_pkg_messaging_message_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Message.ProtoReflect.Descriptor instead.
func (*Message) Descriptor() ([]byte, []int) {
return file_pkg_messaging_message_proto_rawDescGZIP(), []int{0}
}
func (x *Message) GetChannel() string {
if x != nil {
return x.Channel
}
return ""
}
func (x *Message) GetSubtopic() string {
if x != nil {
return x.Subtopic
}
return ""
}
func (x *Message) GetPublisher() string {
if x != nil {
return x.Publisher
}
return ""
}
func (x *Message) GetProtocol() string {
if x != nil {
return x.Protocol
}
return ""
}
func (x *Message) GetPayload() []byte {
if x != nil {
return x.Payload
}
return nil
}
func (x *Message) GetCreated() int64 {
if x != nil {
return x.Created
}
return 0
}
var File_pkg_messaging_message_proto protoreflect.FileDescriptor
var file_pkg_messaging_message_proto_rawDesc = []byte{
0x0a, 0x1b, 0x70, 0x6b, 0x67, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x69, 0x6e, 0x67, 0x2f,
0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x09, 0x6d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x69, 0x6e, 0x67, 0x22, 0xad, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73,
0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x12, 0x1a,
0x0a, 0x08, 0x73, 0x75, 0x62, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
0x52, 0x08, 0x73, 0x75, 0x62, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x12, 0x1c, 0x0a, 0x09, 0x70, 0x75,
0x62, 0x6c, 0x69, 0x73, 0x68, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70,
0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18,
0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x18,
0x0a, 0x07, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52,
0x07, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x42, 0x0d, 0x5a, 0x0b, 0x2e, 0x2f, 0x6d, 0x65,
0x73, 0x73, 0x61, 0x67, 0x69, 0x6e, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_messaging_message_proto_rawDescOnce sync.Once
file_pkg_messaging_message_proto_rawDescData = file_pkg_messaging_message_proto_rawDesc
)
func file_pkg_messaging_message_proto_rawDescGZIP() []byte {
file_pkg_messaging_message_proto_rawDescOnce.Do(func() {
file_pkg_messaging_message_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_messaging_message_proto_rawDescData)
})
return file_pkg_messaging_message_proto_rawDescData
}
var file_pkg_messaging_message_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_pkg_messaging_message_proto_goTypes = []any{
(*Message)(nil), // 0: messaging.Message
}
var file_pkg_messaging_message_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_messaging_message_proto_init() }
func file_pkg_messaging_message_proto_init() {
if File_pkg_messaging_message_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_messaging_message_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*Message); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_messaging_message_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_pkg_messaging_message_proto_goTypes,
DependencyIndexes: file_pkg_messaging_message_proto_depIdxs,
MessageInfos: file_pkg_messaging_message_proto_msgTypes,
}.Build()
File_pkg_messaging_message_proto = out.File
file_pkg_messaging_message_proto_rawDesc = nil
file_pkg_messaging_message_proto_goTypes = nil
file_pkg_messaging_message_proto_depIdxs = nil
}
-17
View File
@@ -1,17 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
syntax = "proto3";
package messaging;
option go_package = "./messaging";
// Message represents a message emitted by the Magistrala adapters layer.
message Message {
string channel = 1;
string subtopic = 2;
string publisher = 3;
string protocol = 4;
bytes payload = 5;
int64 created = 6; // Unix timestamp in nanoseconds
}
-103
View File
@@ -1,103 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
messaging "github.com/absmach/magistrala/pkg/messaging"
mock "github.com/stretchr/testify/mock"
)
// PubSub is an autogenerated mock type for the PubSub type
type PubSub struct {
mock.Mock
}
// Close provides a mock function with given fields:
func (_m *PubSub) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// Publish provides a mock function with given fields: ctx, topic, msg
func (_m *PubSub) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
ret := _m.Called(ctx, topic, msg)
if len(ret) == 0 {
panic("no return value specified for Publish")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, *messaging.Message) error); ok {
r0 = rf(ctx, topic, msg)
} else {
r0 = ret.Error(0)
}
return r0
}
// Subscribe provides a mock function with given fields: ctx, cfg
func (_m *PubSub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
ret := _m.Called(ctx, cfg)
if len(ret) == 0 {
panic("no return value specified for Subscribe")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, messaging.SubscriberConfig) error); ok {
r0 = rf(ctx, cfg)
} else {
r0 = ret.Error(0)
}
return r0
}
// Unsubscribe provides a mock function with given fields: ctx, id, topic
func (_m *PubSub) Unsubscribe(ctx context.Context, id string, topic string) error {
ret := _m.Called(ctx, id, topic)
if len(ret) == 0 {
panic("no return value specified for Unsubscribe")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, id, topic)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewPubSub creates a new instance of PubSub. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewPubSub(t interface {
mock.TestingT
Cleanup(func())
}) *PubSub {
mock := &PubSub{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-11
View File
@@ -1,11 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package mqtt hold the implementation of the Publisher and PubSub
// interfaces for the MQTT messaging system, the internal messaging
// broker of the Magistrala IoT platform. Due to the practical requirements
// implementation Publisher is created alongside PubSub. The reason for
// this is that Subscriber implementation of MQTT brings the burden of
// additional struct fields which are not used by Publisher. Subscriber
// is not implemented separately because PubSub can be used where Subscriber is needed.
package mqtt
-61
View File
@@ -1,61 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package mqtt
import (
"context"
"errors"
"time"
"github.com/absmach/magistrala/pkg/messaging"
mqtt "github.com/eclipse/paho.mqtt.golang"
)
var errPublishTimeout = errors.New("failed to publish due to timeout reached")
var _ messaging.Publisher = (*publisher)(nil)
type publisher struct {
client mqtt.Client
timeout time.Duration
qos uint8
}
// NewPublisher returns a new MQTT message publisher.
func NewPublisher(address string, qos uint8, timeout time.Duration) (messaging.Publisher, error) {
client, err := newClient(address, "mqtt-publisher", timeout)
if err != nil {
return nil, err
}
ret := publisher{
client: client,
timeout: timeout,
qos: qos,
}
return ret, nil
}
func (pub publisher) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
if topic == "" {
return ErrEmptyTopic
}
// Publish only the payload and not the whole message.
token := pub.client.Publish(topic, byte(pub.qos), false, msg.GetPayload())
if token.Error() != nil {
return token.Error()
}
if ok := token.WaitTimeout(pub.timeout); !ok {
return errPublishTimeout
}
return nil
}
func (pub publisher) Close() error {
pub.client.Disconnect(uint(pub.timeout))
return nil
}
-230
View File
@@ -1,230 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package mqtt
import (
"context"
"errors"
"fmt"
"log/slog"
"sync"
"time"
"github.com/absmach/magistrala/pkg/messaging"
mqtt "github.com/eclipse/paho.mqtt.golang"
"google.golang.org/protobuf/proto"
)
const username = "magistrala-mqtt"
var (
// ErrConnect indicates that connection to MQTT broker failed.
ErrConnect = errors.New("failed to connect to MQTT broker")
// errSubscribeTimeout indicates that the subscription failed due to timeout.
errSubscribeTimeout = errors.New("failed to subscribe due to timeout reached")
// errUnsubscribeTimeout indicates that unsubscribe failed due to timeout.
errUnsubscribeTimeout = errors.New("failed to unsubscribe due to timeout reached")
// errUnsubscribeDeleteTopic indicates that unsubscribe failed because the topic was deleted.
errUnsubscribeDeleteTopic = errors.New("failed to unsubscribe due to deletion of topic")
// ErrNotSubscribed indicates that the topic is not subscribed to.
ErrNotSubscribed = errors.New("not subscribed")
// ErrEmptyTopic indicates the absence of topic.
ErrEmptyTopic = errors.New("empty topic")
// ErrEmptyID indicates the absence of ID.
ErrEmptyID = errors.New("empty ID")
)
var _ messaging.PubSub = (*pubsub)(nil)
type subscription struct {
client mqtt.Client
topics []string
cancel func() error
}
type pubsub struct {
publisher
logger *slog.Logger
mu sync.RWMutex
address string
timeout time.Duration
subscriptions map[string]subscription
}
// NewPubSub returns MQTT message publisher/subscriber.
func NewPubSub(url string, qos uint8, timeout time.Duration, logger *slog.Logger) (messaging.PubSub, error) {
client, err := newClient(url, "mqtt-publisher", timeout)
if err != nil {
return nil, err
}
ret := &pubsub{
publisher: publisher{
client: client,
timeout: timeout,
qos: qos,
},
address: url,
timeout: timeout,
logger: logger,
subscriptions: make(map[string]subscription),
}
return ret, nil
}
func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
if cfg.ID == "" {
return ErrEmptyID
}
if cfg.Topic == "" {
return ErrEmptyTopic
}
ps.mu.Lock()
defer ps.mu.Unlock()
s, ok := ps.subscriptions[cfg.ID]
// If the client exists, check if it's subscribed to the topic and unsubscribe if needed.
switch ok {
case true:
if ok := s.contains(cfg.Topic); ok {
if err := s.unsubscribe(cfg.Topic, ps.timeout); err != nil {
return err
}
}
default:
client, err := newClient(ps.address, cfg.ID, ps.timeout)
if err != nil {
return err
}
s = subscription{
client: client,
topics: []string{},
cancel: cfg.Handler.Cancel,
}
}
s.topics = append(s.topics, cfg.Topic)
ps.subscriptions[cfg.ID] = s
token := s.client.Subscribe(cfg.Topic, byte(ps.qos), ps.mqttHandler(cfg.Handler))
if token.Error() != nil {
return token.Error()
}
if ok := token.WaitTimeout(ps.timeout); !ok {
return errSubscribeTimeout
}
return nil
}
func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error {
if id == "" {
return ErrEmptyID
}
if topic == "" {
return ErrEmptyTopic
}
ps.mu.Lock()
defer ps.mu.Unlock()
s, ok := ps.subscriptions[id]
if !ok || !s.contains(topic) {
return ErrNotSubscribed
}
if err := s.unsubscribe(topic, ps.timeout); err != nil {
return err
}
ps.subscriptions[id] = s
if len(s.topics) == 0 {
delete(ps.subscriptions, id)
}
return nil
}
func (s *subscription) unsubscribe(topic string, timeout time.Duration) error {
if s.cancel != nil {
if err := s.cancel(); err != nil {
return err
}
}
token := s.client.Unsubscribe(topic)
if token.Error() != nil {
return token.Error()
}
if ok := token.WaitTimeout(timeout); !ok {
return errUnsubscribeTimeout
}
if ok := s.delete(topic); !ok {
return errUnsubscribeDeleteTopic
}
return token.Error()
}
func newClient(address, id string, timeout time.Duration) (mqtt.Client, error) {
opts := mqtt.NewClientOptions().
SetUsername(username).
AddBroker(address).
SetClientID(id)
client := mqtt.NewClient(opts)
token := client.Connect()
if token.Error() != nil {
return nil, token.Error()
}
if ok := token.WaitTimeout(timeout); !ok {
return nil, ErrConnect
}
return client, nil
}
func (ps *pubsub) mqttHandler(h messaging.MessageHandler) mqtt.MessageHandler {
return func(_ mqtt.Client, m mqtt.Message) {
var msg messaging.Message
if err := proto.Unmarshal(m.Payload(), &msg); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
return
}
if err := h.Handle(&msg); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to handle Magistrala message: %s", err))
}
}
}
// Contains checks if a topic is present.
func (s subscription) contains(topic string) bool {
return s.indexOf(topic) != -1
}
// Finds the index of an item in the topics.
func (s subscription) indexOf(element string) int {
for k, v := range s.topics {
if element == v {
return k
}
}
return -1
}
// Deletes a topic from the slice.
func (s *subscription) delete(topic string) bool {
index := s.indexOf(topic)
if index == -1 {
return false
}
topics := make([]string, len(s.topics)-1)
copy(topics[:index], s.topics[:index])
copy(topics[index:], s.topics[index+1:])
s.topics = topics
return true
}
-474
View File
@@ -1,474 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package mqtt_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/absmach/magistrala/pkg/messaging"
mqttpubsub "github.com/absmach/magistrala/pkg/messaging/mqtt"
mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
const (
topic = "topic"
chansPrefix = "channels"
channel = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
subtopic = "engine"
tokenTimeout = 100 * time.Millisecond
)
var data = []byte("payload")
// ErrFailedHandleMessage indicates that the message couldn't be handled.
var errFailedHandleMessage = errors.New("failed to handle magistrala message")
func TestPublisher(t *testing.T) {
msgChan := make(chan []byte)
// Subscribing with topic, and with subtopic, so that we can publish messages.
client, err := newClient(address, "clientID1", brokerTimeout)
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
token := client.Subscribe(topic, qos, func(_ mqtt.Client, m mqtt.Message) {
msgChan <- m.Payload()
})
if ok := token.WaitTimeout(tokenTimeout); !ok {
assert.Fail(t, fmt.Sprintf("failed to subscribe to topic %s", topic))
}
assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error()))
token = client.Subscribe(fmt.Sprintf("%s.%s", topic, subtopic), qos, func(_ mqtt.Client, m mqtt.Message) {
msgChan <- m.Payload()
})
if ok := token.WaitTimeout(tokenTimeout); !ok {
assert.Fail(t, fmt.Sprintf("failed to subscribe to topic %s", fmt.Sprintf("%s.%s", topic, subtopic)))
}
assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error()))
t.Cleanup(func() {
token := client.Unsubscribe(topic, fmt.Sprintf("%s.%s", topic, subtopic))
token.WaitTimeout(tokenTimeout)
assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error()))
client.Disconnect(100)
})
// Test publish with an empty topic.
err = pubsub.Publish(context.TODO(), "", &messaging.Message{Payload: data})
assert.Equal(t, err, mqttpubsub.ErrEmptyTopic, fmt.Sprintf("Publish with empty topic: expected: %s, got: %s", mqttpubsub.ErrEmptyTopic, err))
cases := []struct {
desc string
channel string
subtopic string
payload []byte
}{
{
desc: "publish message with nil payload",
payload: nil,
},
{
desc: "publish message with string payload",
payload: data,
},
{
desc: "publish message with channel",
payload: data,
channel: channel,
},
{
desc: "publish message with subtopic",
payload: data,
subtopic: subtopic,
},
{
desc: "publish message with channel and subtopic",
payload: data,
channel: channel,
subtopic: subtopic,
},
}
for _, tc := range cases {
expectedMsg := messaging.Message{
Publisher: "clientID11",
Channel: tc.channel,
Subtopic: tc.subtopic,
Payload: tc.payload,
}
err := pubsub.Publish(context.TODO(), topic, &expectedMsg)
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error: %s\n", tc.desc, err))
data, err := proto.Marshal(&expectedMsg)
assert.Nil(t, err, fmt.Sprintf("%s: failed to serialize protobuf error: %s\n", tc.desc, err))
receivedMsg := <-msgChan
if tc.payload != nil {
assert.Equal(t, expectedMsg.GetPayload(), receivedMsg, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, data, receivedMsg))
}
}
}
func TestSubscribe(t *testing.T) {
msgChan := make(chan *messaging.Message)
// Creating client to Publish messages to subscribed topic.
client, err := newClient(address, "magistrala", brokerTimeout)
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
t.Cleanup(func() {
client.Unsubscribe()
client.Disconnect(100)
})
cases := []struct {
desc string
topic string
clientID string
err error
handler messaging.MessageHandler
}{
{
desc: "Subscribe to a topic with an ID",
topic: topic,
clientID: "clientid1",
err: nil,
handler: handler{false, "clientid1", msgChan},
},
{
desc: "Subscribe to the same topic with a different ID",
topic: topic,
clientID: "clientid2",
err: nil,
handler: handler{false, "clientid2", msgChan},
},
{
desc: "Subscribe to an already subscribed topic with an ID",
topic: topic,
clientID: "clientid1",
err: nil,
handler: handler{false, "clientid1", msgChan},
},
{
desc: "Subscribe to a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s", topic, subtopic),
clientID: "clientid1",
err: nil,
handler: handler{false, "clientid1", msgChan},
},
{
desc: "Subscribe to an already subscribed topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s", topic, subtopic),
clientID: "clientid1",
err: nil,
handler: handler{false, "clientid1", msgChan},
},
{
desc: "Subscribe to an empty topic with an ID",
topic: "",
clientID: "clientid1",
err: mqttpubsub.ErrEmptyTopic,
handler: handler{false, "clientid1", msgChan},
},
{
desc: "Subscribe to a topic with empty id",
topic: topic,
clientID: "",
err: mqttpubsub.ErrEmptyID,
handler: handler{false, "", msgChan},
},
}
for _, tc := range cases {
subCfg := messaging.SubscriberConfig{
ID: tc.clientID,
Topic: tc.topic,
Handler: tc.handler,
}
err = pubsub.Subscribe(context.TODO(), subCfg)
assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, err, tc.err))
if tc.err == nil {
expectedMsg := messaging.Message{
Publisher: "clientID1",
Channel: channel,
Subtopic: subtopic,
Payload: data,
}
data, err := proto.Marshal(&expectedMsg)
assert.Nil(t, err, fmt.Sprintf("%s: failed to serialize protobuf error: %s\n", tc.desc, err))
token := client.Publish(tc.topic, qos, false, data)
token.WaitTimeout(tokenTimeout)
assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error()))
receivedMsg := <-msgChan
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
}
}
}
func TestPubSub(t *testing.T) {
msgChan := make(chan *messaging.Message)
cases := []struct {
desc string
topic string
clientID string
err error
handler messaging.MessageHandler
}{
{
desc: "Subscribe to a topic with an ID",
topic: topic,
clientID: "clientid7",
err: nil,
handler: handler{false, "clientid7", msgChan},
},
{
desc: "Subscribe to the same topic with a different ID",
topic: topic,
clientID: "clientid8",
err: nil,
handler: handler{false, "clientid8", msgChan},
},
{
desc: "Subscribe to a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s", topic, subtopic),
clientID: "clientid7",
err: nil,
handler: handler{false, "clientid7", msgChan},
},
{
desc: "Subscribe to an empty topic with an ID",
topic: "",
clientID: "clientid7",
err: mqttpubsub.ErrEmptyTopic,
handler: handler{false, "clientid7", msgChan},
},
{
desc: "Subscribe to a topic with empty id",
topic: topic,
clientID: "",
err: mqttpubsub.ErrEmptyID,
handler: handler{false, "", msgChan},
},
}
for _, tc := range cases {
subCfg := messaging.SubscriberConfig{
ID: tc.clientID,
Topic: tc.topic,
Handler: tc.handler,
}
err := pubsub.Subscribe(context.TODO(), subCfg)
assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, err, tc.err))
if tc.err == nil {
// Use pubsub to subscribe to a topic, and then publish messages to that topic.
expectedMsg := messaging.Message{
Publisher: "clientID",
Channel: channel,
Subtopic: subtopic,
Payload: data,
}
data, err := proto.Marshal(&expectedMsg)
assert.Nil(t, err, fmt.Sprintf("%s: failed to serialize protobuf error: %s\n", tc.desc, err))
msg := messaging.Message{
Payload: data,
}
// Publish message, and then receive it on message channel.
err = pubsub.Publish(context.TODO(), topic, &msg)
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error: %s\n", tc.desc, err))
receivedMsg := <-msgChan
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
}
}
}
func TestUnsubscribe(t *testing.T) {
msgChan := make(chan *messaging.Message)
cases := []struct {
desc string
topic string
clientID string
err error
subscribe bool // True for subscribe and false for unsubscribe.
handler messaging.MessageHandler
}{
{
desc: "Subscribe to a topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid4",
err: nil,
subscribe: true,
handler: handler{false, "clientid4", msgChan},
},
{
desc: "Subscribe to the same topic with a different ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid9",
err: nil,
subscribe: true,
handler: handler{false, "clientid9", msgChan},
},
{
desc: "Unsubscribe from a topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid4",
err: nil,
subscribe: false,
handler: handler{false, "clientid4", msgChan},
},
{
desc: "Unsubscribe from same topic with different ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid9",
err: nil,
subscribe: false,
handler: handler{false, "clientid9", msgChan},
},
{
desc: "Unsubscribe from a non-existent topic with an ID",
topic: "h",
clientID: "clientid4",
err: mqttpubsub.ErrNotSubscribed,
subscribe: false,
handler: handler{false, "clientid4", msgChan},
},
{
desc: "Unsubscribe from an already unsubscribed topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid4",
err: mqttpubsub.ErrNotSubscribed,
subscribe: false,
handler: handler{false, "clientid4", msgChan},
},
{
desc: "Subscribe to a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientidd4",
err: nil,
subscribe: true,
handler: handler{false, "clientidd4", msgChan},
},
{
desc: "Unsubscribe from a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientidd4",
err: nil,
subscribe: false,
handler: handler{false, "clientidd4", msgChan},
},
{
desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientid4",
err: mqttpubsub.ErrNotSubscribed,
subscribe: false,
handler: handler{false, "clientid4", msgChan},
},
{
desc: "Unsubscribe from an empty topic with an ID",
topic: "",
clientID: "clientid4",
err: mqttpubsub.ErrEmptyTopic,
subscribe: false,
handler: handler{false, "clientid4", msgChan},
},
{
desc: "Unsubscribe from a topic with empty ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "",
err: mqttpubsub.ErrEmptyID,
subscribe: false,
handler: handler{false, "", msgChan},
},
{
desc: "Subscribe to a new topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"),
clientID: "clientid55",
err: nil,
subscribe: true,
handler: handler{true, "clientid5", msgChan},
},
{
desc: "Unsubscribe from a topic with an ID with failing handler",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"),
clientID: "clientid55",
err: errFailedHandleMessage,
subscribe: false,
handler: handler{true, "clientid5", msgChan},
},
{
desc: "Subscribe to a new topic with subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic),
clientID: "clientid55",
err: nil,
subscribe: true,
handler: handler{true, "clientid5", msgChan},
},
{
desc: "Unsubscribe from a topic with subtopic with an ID with failing handler",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic),
clientID: "clientid55",
err: errFailedHandleMessage,
subscribe: false,
handler: handler{true, "clientid5", msgChan},
},
}
for _, tc := range cases {
subCfg := messaging.SubscriberConfig{
ID: tc.clientID,
Topic: tc.topic,
Handler: tc.handler,
}
switch tc.subscribe {
case true:
err := pubsub.Subscribe(context.TODO(), subCfg)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
default:
err := pubsub.Unsubscribe(context.TODO(), tc.clientID, tc.topic)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
}
}
}
type handler struct {
fail bool
publisher string
msgChan chan *messaging.Message
}
func (h handler) Handle(msg *messaging.Message) error {
if msg.GetPublisher() != h.publisher {
h.msgChan <- msg
}
return nil
}
func (h handler) Cancel() error {
if h.fail {
return errFailedHandleMessage
}
return nil
}
-121
View File
@@ -1,121 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package mqtt_test
import (
"fmt"
"log"
"log/slog"
"os"
"os/signal"
"syscall"
"testing"
"time"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/messaging"
mqttpubsub "github.com/absmach/magistrala/pkg/messaging/mqtt"
mqtt "github.com/eclipse/paho.mqtt.golang"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
)
var (
pubsub messaging.PubSub
logger *slog.Logger
address string
)
const (
username = "magistrala-mqtt"
qos = 2
port = "1883/tcp"
brokerTimeout = 30 * time.Second
poolMaxWait = 120 * time.Second
)
func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "eclipse-mosquitto",
Tag: "1.6.15",
}, func(config *docker.HostConfig) {
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
handleInterrupt(pool, container)
address = fmt.Sprintf("%s:%s", "localhost", container.GetPort(port))
pool.MaxWait = poolMaxWait
logger, err = mglog.New(os.Stdout, "debug")
if err != nil {
log.Fatal(err.Error())
}
if err := pool.Retry(func() error {
pubsub, err = mqttpubsub.NewPubSub(address, 2, brokerTimeout, logger)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
code := m.Run()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
defer func() {
err = pubsub.Close()
if err != nil {
log.Fatal(err.Error())
}
}()
}
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
c := make(chan os.Signal, 2)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(0)
}()
}
func newClient(address, id string, timeout time.Duration) (mqtt.Client, error) {
opts := mqtt.NewClientOptions().
SetUsername(username).
AddBroker(address).
SetClientID(id)
client := mqtt.NewClient(opts)
token := client.Connect()
if token.Error() != nil {
return nil, token.Error()
}
ok := token.WaitTimeout(timeout)
if !ok {
return nil, mqttpubsub.ErrConnect
}
if token.Error() != nil {
return nil, token.Error()
}
return client, nil
}
-11
View File
@@ -1,11 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package nats hold the implementation of the Publisher and PubSub
// interfaces for the NATS messaging system, the internal messaging
// broker of the Magistrala IoT platform. Due to the practical requirements
// implementation Publisher is created alongside PubSub. The reason for
// this is that Subscriber implementation of NATS brings the burden of
// additional struct fields which are not used by Publisher. Subscriber
// is not implemented separately because PubSub can be used where Subscriber is needed.
package nats
-56
View File
@@ -1,56 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nats
import (
"errors"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/nats-io/nats.go/jetstream"
)
// ErrInvalidType is returned when the provided value is not of the expected type.
var ErrInvalidType = errors.New("invalid type")
// Prefix sets the prefix for the publisher.
func Prefix(prefix string) messaging.Option {
return func(val interface{}) error {
p, ok := val.(*publisher)
if !ok {
return ErrInvalidType
}
p.prefix = prefix
return nil
}
}
// JSStream sets the JetStream for the publisher.
func JSStream(stream jetstream.JetStream) messaging.Option {
return func(val interface{}) error {
p, ok := val.(*publisher)
if !ok {
return ErrInvalidType
}
p.js = stream
return nil
}
}
// Stream sets the Stream for the subscriber.
func Stream(stream jetstream.Stream) messaging.Option {
return func(val interface{}) error {
p, ok := val.(*pubsub)
if !ok {
return ErrInvalidType
}
p.stream = stream
return nil
}
}
-88
View File
@@ -1,88 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nats
import (
"context"
"fmt"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/messaging"
broker "github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
"google.golang.org/protobuf/proto"
)
const (
// A maximum number of reconnect attempts before NATS connection closes permanently.
// Value -1 represents an unlimited number of reconnect retries, i.e. the client
// will never give up on retrying to re-establish connection to NATS server.
maxReconnects = -1
// reconnectBufSize is obtained from the maximum number of unpublished events
// multiplied by the approximate maximum size of a single event.
reconnectBufSize = events.MaxUnpublishedEvents * (1024 * 1024)
)
var _ messaging.Publisher = (*publisher)(nil)
type publisher struct {
js jetstream.JetStream
conn *broker.Conn
prefix string
}
// NewPublisher returns NATS message Publisher.
func NewPublisher(ctx context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) {
conn, err := broker.Connect(url, broker.MaxReconnects(maxReconnects), broker.ReconnectBufSize(int(reconnectBufSize)))
if err != nil {
return nil, err
}
js, err := jetstream.New(conn)
if err != nil {
return nil, err
}
if _, err := js.CreateStream(ctx, jsStreamConfig); err != nil {
return nil, err
}
ret := &publisher{
js: js,
conn: conn,
prefix: chansPrefix,
}
for _, opt := range opts {
if err := opt(ret); err != nil {
return nil, err
}
}
return ret, nil
}
func (pub *publisher) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
if topic == "" {
return ErrEmptyTopic
}
data, err := proto.Marshal(msg)
if err != nil {
return err
}
subject := fmt.Sprintf("%s.%s", pub.prefix, topic)
if msg.GetSubtopic() != "" {
subject = fmt.Sprintf("%s.%s", subject, msg.GetSubtopic())
}
_, err = pub.js.Publish(ctx, subject, data)
return err
}
func (pub *publisher) Close() error {
pub.conn.Close()
return nil
}
-174
View File
@@ -1,174 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nats
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"time"
"github.com/absmach/magistrala/pkg/messaging"
broker "github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
"google.golang.org/protobuf/proto"
)
const chansPrefix = "channels"
// Publisher and Subscriber errors.
var (
ErrNotSubscribed = errors.New("not subscribed")
ErrEmptyTopic = errors.New("empty topic")
ErrEmptyID = errors.New("empty id")
jsStreamConfig = jetstream.StreamConfig{
Name: "channels",
Description: "Magistrala stream for sending and receiving messages in between Magistrala channels",
Subjects: []string{"channels.>"},
Retention: jetstream.LimitsPolicy,
MaxMsgsPerSubject: 1e6,
MaxAge: time.Hour * 24,
MaxMsgSize: 1024 * 1024,
Discard: jetstream.DiscardOld,
Storage: jetstream.FileStorage,
}
)
var _ messaging.PubSub = (*pubsub)(nil)
type pubsub struct {
publisher
logger *slog.Logger
stream jetstream.Stream
}
// NewPubSub returns NATS message publisher/subscriber.
// Parameter queue specifies the queue for the Subscribe method.
// If queue is specified (is not an empty string), Subscribe method
// will execute NATS QueueSubscribe which is conceptually different
// from ordinary subscribe. For more information, please take a look
// here: https://docs.nats.io/developing-with-nats/receiving/queues.
// If the queue is empty, Subscribe will be used.
func NewPubSub(ctx context.Context, url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) {
conn, err := broker.Connect(url, broker.MaxReconnects(maxReconnects))
if err != nil {
return nil, err
}
js, err := jetstream.New(conn)
if err != nil {
return nil, err
}
stream, err := js.CreateStream(ctx, jsStreamConfig)
if err != nil {
return nil, err
}
ret := &pubsub{
publisher: publisher{
js: js,
conn: conn,
prefix: chansPrefix,
},
stream: stream,
logger: logger,
}
for _, opt := range opts {
if err := opt(ret); err != nil {
return nil, err
}
}
return ret, nil
}
func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
if cfg.ID == "" {
return ErrEmptyID
}
if cfg.Topic == "" {
return ErrEmptyTopic
}
nh := ps.natsHandler(cfg.Handler)
consumerConfig := jetstream.ConsumerConfig{
Name: formatConsumerName(cfg.Topic, cfg.ID),
Durable: formatConsumerName(cfg.Topic, cfg.ID),
Description: fmt.Sprintf("Magistrala consumer of id %s for cfg.Topic %s", cfg.ID, cfg.Topic),
DeliverPolicy: jetstream.DeliverNewPolicy,
FilterSubject: cfg.Topic,
}
switch cfg.DeliveryPolicy {
case messaging.DeliverNewPolicy:
consumerConfig.DeliverPolicy = jetstream.DeliverNewPolicy
case messaging.DeliverAllPolicy:
consumerConfig.DeliverPolicy = jetstream.DeliverAllPolicy
}
consumer, err := ps.stream.CreateOrUpdateConsumer(ctx, consumerConfig)
if err != nil {
return fmt.Errorf("failed to create consumer: %w", err)
}
if _, err = consumer.Consume(nh); err != nil {
return fmt.Errorf("failed to consume: %w", err)
}
return nil
}
func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error {
if id == "" {
return ErrEmptyID
}
if topic == "" {
return ErrEmptyTopic
}
err := ps.stream.DeleteConsumer(ctx, formatConsumerName(topic, id))
switch {
case errors.Is(err, jetstream.ErrConsumerNotFound):
return ErrNotSubscribed
default:
return err
}
}
func (ps *pubsub) natsHandler(h messaging.MessageHandler) func(m jetstream.Msg) {
return func(m jetstream.Msg) {
var msg messaging.Message
if err := proto.Unmarshal(m.Data(), &msg); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
return
}
if err := h.Handle(&msg); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to handle Magistrala message: %s", err))
}
if err := m.Ack(); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to ack message: %s", err))
}
}
}
func formatConsumerName(topic, id string) string {
// A durable name cannot contain whitespace, ., *, >, path separators (forward or backwards slash), and non-printable characters.
chars := []string{
" ", "_",
".", "_",
"*", "_",
">", "_",
"/", "_",
"\\", "_",
}
topic = strings.NewReplacer(chars...).Replace(topic)
return fmt.Sprintf("%s-%s", topic, id)
}
-297
View File
@@ -1,297 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nats_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/nats"
"github.com/stretchr/testify/assert"
)
const (
topic = "topic"
chansPrefix = "channels"
channel = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
subtopic = "engine"
clientID = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
)
var (
msgChan = make(chan *messaging.Message)
message = &messaging.Message{
Channel: channel,
Subtopic: subtopic,
Publisher: "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b",
Protocol: "mqtt",
Payload: []byte("payload"),
Created: time.Now().UnixNano(),
}
)
func TestPublisher(t *testing.T) {
subCfg := messaging.SubscriberConfig{
ID: clientID,
Topic: fmt.Sprintf("%s.>", chansPrefix),
Handler: handler{},
}
err := pubsub.Subscribe(context.TODO(), subCfg)
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
cases := []struct {
desc string
topic string
subtopic string
message *messaging.Message
error error
}{
{
desc: "publish message with empty message",
topic: channel,
subtopic: subtopic,
message: &messaging.Message{},
error: nil,
},
{
desc: "publish message with message",
topic: channel,
subtopic: subtopic,
message: message,
error: nil,
},
{
desc: "publish message with topic and empty subtopic",
topic: channel,
subtopic: "",
message: message,
error: nil,
},
{
desc: "publish message with subtopic and empty topic",
topic: "",
subtopic: subtopic,
message: message,
error: nats.ErrEmptyTopic,
},
{
desc: "publish message with topic and subtopic",
topic: channel,
subtopic: subtopic,
message: message,
error: nil,
},
}
for _, tc := range cases {
tc.message.Subtopic = tc.subtopic
err := pubsub.Publish(context.TODO(), tc.topic, tc.message)
assert.Equal(t, tc.error, err, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, tc.error, err))
if err == nil {
receivedMsg := <-msgChan
assert.Equal(t, tc.message.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, tc.message.Payload, receivedMsg))
assert.Equal(t, tc.message.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
assert.Equal(t, tc.message.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
assert.Equal(t, tc.message.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
assert.Equal(t, tc.message.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
assert.Equal(t, tc.message.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
assert.Equal(t, tc.message.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
}
}
}
func TestPubsub(t *testing.T) {
// Test Subscribe and Unsubscribe.
subcases := []struct {
desc string
topic string
clientID string
errorMessage error
pubsub bool // true for subscribe and false for unsubscribe.
handler messaging.MessageHandler
}{
{
desc: "Subscribe to a topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid1",
errorMessage: nil,
pubsub: true,
handler: handler{},
},
{
desc: "Subscribe using malformed topic and ID",
topic: fmt.Sprintf("%s.>", chansPrefix),
clientID: "clientid1",
errorMessage: nil,
pubsub: true,
handler: handler{},
},
{
desc: "Subscribe using malformed topic and ID",
topic: fmt.Sprintf("%s.*", chansPrefix),
clientID: "clientid1",
errorMessage: nil,
pubsub: true,
handler: handler{},
},
{
desc: "Subscribe to the same topic with a different ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid2",
errorMessage: nil,
pubsub: true,
handler: handler{},
},
{
desc: "Subscribe to an already subscribed topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid1",
errorMessage: nil,
pubsub: true,
handler: handler{},
},
{
desc: "Unsubscribe from a topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid1",
errorMessage: nil,
pubsub: false,
handler: handler{},
},
{
desc: "Unsubscribe from a non-existent topic with an ID",
topic: "h",
clientID: "clientid1",
errorMessage: nats.ErrNotSubscribed,
pubsub: false,
handler: handler{},
},
{
desc: "Unsubscribe from the same topic with a different ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientidd2",
errorMessage: nats.ErrNotSubscribed,
pubsub: false,
handler: handler{},
},
{
desc: "Unsubscribe from the same topic with a different ID not subscribed",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientidd3",
errorMessage: nats.ErrNotSubscribed,
pubsub: false,
handler: handler{},
},
{
desc: "Unsubscribe from an already unsubscribed topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid1",
errorMessage: nats.ErrNotSubscribed,
pubsub: false,
handler: handler{},
},
{
desc: "Subscribe to a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientidd1",
errorMessage: nil,
pubsub: true,
handler: handler{},
},
{
desc: "Subscribe to an already subscribed topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientidd1",
errorMessage: nil,
pubsub: true,
handler: handler{},
},
{
desc: "Unsubscribe from a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientidd1",
errorMessage: nil,
pubsub: false,
handler: handler{},
},
{
desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientid1",
errorMessage: nats.ErrNotSubscribed,
pubsub: false,
handler: handler{},
},
{
desc: "Subscribe to an empty topic with an ID",
topic: "",
clientID: "clientid1",
errorMessage: nats.ErrEmptyTopic,
pubsub: true,
handler: handler{},
},
{
desc: "Unsubscribe from an empty topic with an ID",
topic: "",
clientID: "clientid1",
errorMessage: nats.ErrEmptyTopic,
pubsub: false,
handler: handler{},
},
{
desc: "Subscribe to a topic with empty id",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "",
errorMessage: nats.ErrEmptyID,
pubsub: true,
handler: handler{},
},
{
desc: "Unsubscribe from a topic with empty id",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "",
errorMessage: nats.ErrEmptyID,
pubsub: false,
handler: handler{},
},
}
for _, pc := range subcases {
subCfg := messaging.SubscriberConfig{
ID: pc.clientID,
Topic: pc.topic,
Handler: pc.handler,
}
if pc.pubsub == true {
err := pubsub.Subscribe(context.TODO(), subCfg)
if pc.errorMessage == nil {
assert.Nil(t, err, fmt.Sprintf("%s expected %+v got %+v\n", pc.desc, pc.errorMessage, err))
} else {
assert.Equal(t, err, pc.errorMessage, fmt.Sprintf("%s expected %+v got %+v\n", pc.desc, pc.errorMessage, err))
}
} else {
err := pubsub.Unsubscribe(context.TODO(), pc.clientID, pc.topic)
if pc.errorMessage == nil {
assert.Nil(t, err, fmt.Sprintf("%s expected %+v got %+v\n", pc.desc, pc.errorMessage, err))
} else {
assert.Equal(t, err, pc.errorMessage, fmt.Sprintf("%s expected %+v got %+v\n", pc.desc, pc.errorMessage, err))
}
}
}
}
type handler struct{}
func (h handler) Handle(msg *messaging.Message) error {
msgChan <- msg
return nil
}
func (h handler) Cancel() error {
return nil
}
-80
View File
@@ -1,80 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package nats_test
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"syscall"
"testing"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/nats"
"github.com/ory/dockertest/v3"
)
var (
publisher messaging.Publisher
pubsub messaging.PubSub
)
func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "nats",
Tag: "2.10.9-alpine",
Cmd: []string{"-DVV", "-js"},
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
handleInterrupt(pool, container)
address := fmt.Sprintf("nats://%s:%s", "localhost", container.GetPort("4222/tcp"))
if err := pool.Retry(func() error {
publisher, err = nats.NewPublisher(context.Background(), address)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
logger, err := mglog.New(os.Stdout, "error")
if err != nil {
log.Fatal(err.Error())
}
if err := pool.Retry(func() error {
pubsub, err = nats.NewPubSub(context.Background(), address, logger)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
code := m.Run()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
}
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
c := make(chan os.Signal, 2)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(0)
}()
}
-12
View File
@@ -1,12 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package tracing provides tracing instrumentation for Magistrala things policies service.
//
// This package provides tracing middleware for Magistrala things policies service.
// It can be used to trace incoming requests and add tracing capabilities to
// Magistrala things policies service.
//
// For more details about tracing instrumentation for Magistrala messaging refer
// to the documentation at https://docs.magistrala.abstractmachines.fr/tracing/.
package tracing
-52
View File
@@ -1,52 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"context"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/tracing"
"github.com/absmach/magistrala/pkg/server"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// Traced operations.
const publishOP = "publish"
var defaultAttributes = []attribute.KeyValue{
attribute.String("messaging.system", "nats"),
attribute.String("network.protocol.name", "nats"),
attribute.String("network.protocol.version", "2.2.4"),
}
var _ messaging.Publisher = (*publisherMiddleware)(nil)
type publisherMiddleware struct {
publisher messaging.Publisher
tracer trace.Tracer
host server.Config
}
func NewPublisher(config server.Config, tracer trace.Tracer, publisher messaging.Publisher) messaging.Publisher {
pub := &publisherMiddleware{
publisher: publisher,
tracer: tracer,
host: config,
}
return pub
}
func (pm *publisherMiddleware) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
ctx, span := tracing.CreateSpan(ctx, publishOP, msg.GetPublisher(), topic, msg.GetSubtopic(), len(msg.GetPayload()), pm.host, trace.SpanKindClient, pm.tracer)
defer span.End()
span.SetAttributes(defaultAttributes...)
return pm.publisher.Publish(ctx, topic, msg)
}
func (pm *publisherMiddleware) Close() error {
return pm.publisher.Close()
}
-96
View File
@@ -1,96 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"context"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/tracing"
"github.com/absmach/magistrala/pkg/server"
"go.opentelemetry.io/otel/trace"
)
// Constants to define different operations to be traced.
const (
subscribeOP = "receive"
unsubscribeOp = "unsubscribe" // This is not specified in the open telemetry spec.
processOp = "process"
)
var _ messaging.PubSub = (*pubsubMiddleware)(nil)
type pubsubMiddleware struct {
publisherMiddleware
pubsub messaging.PubSub
host server.Config
}
// NewPubSub creates a new pubsub middleware that traces pubsub operations.
func NewPubSub(config server.Config, tracer trace.Tracer, pubsub messaging.PubSub) messaging.PubSub {
pb := &pubsubMiddleware{
publisherMiddleware: publisherMiddleware{
publisher: pubsub,
tracer: tracer,
host: config,
},
pubsub: pubsub,
host: config,
}
return pb
}
// Subscribe creates a new subscription and traces the operation.
func (pm *pubsubMiddleware) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
ctx, span := tracing.CreateSpan(ctx, subscribeOP, cfg.ID, cfg.Topic, "", 0, pm.host, trace.SpanKindClient, pm.tracer)
defer span.End()
span.SetAttributes(defaultAttributes...)
cfg.Handler = &traceHandler{
ctx: ctx,
handler: cfg.Handler,
tracer: pm.tracer,
host: pm.host,
topic: cfg.Topic,
clientID: cfg.ID,
}
return pm.pubsub.Subscribe(ctx, cfg)
}
// Unsubscribe removes an existing subscription and traces the operation.
func (pm *pubsubMiddleware) Unsubscribe(ctx context.Context, id, topic string) error {
ctx, span := tracing.CreateSpan(ctx, unsubscribeOp, id, topic, "", 0, pm.host, trace.SpanKindInternal, pm.tracer)
defer span.End()
span.SetAttributes(defaultAttributes...)
return pm.pubsub.Unsubscribe(ctx, id, topic)
}
// TraceHandler is used to trace the message handling operation.
type traceHandler struct {
ctx context.Context
handler messaging.MessageHandler
tracer trace.Tracer
host server.Config
topic string
clientID string
}
// Handle instruments the message handling operation.
func (h *traceHandler) Handle(msg *messaging.Message) error {
_, span := tracing.CreateSpan(h.ctx, processOp, h.clientID, h.topic, msg.GetSubtopic(), len(msg.GetPayload()), h.host, trace.SpanKindConsumer, h.tracer)
defer span.End()
span.SetAttributes(defaultAttributes...)
return h.handler.Handle(msg)
}
// Cancel cancels the message handling operation.
func (h *traceHandler) Cancel() error {
return h.handler.Cancel()
}
-82
View File
@@ -1,82 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package messaging
import "context"
type DeliveryPolicy uint8
const (
// DeliverNewPolicy will only deliver new messages that are sent after the consumer is created.
// This is the default policy.
DeliverNewPolicy DeliveryPolicy = iota
// DeliverAllPolicy starts delivering messages from the very beginning of a stream.
DeliverAllPolicy
)
// Publisher specifies message publishing API.
type Publisher interface {
// Publishes message to the stream.
Publish(ctx context.Context, topic string, msg *Message) error
// Close gracefully closes message publisher's connection.
Close() error
}
// MessageHandler represents Message handler for Subscriber.
type MessageHandler interface {
// Handle handles messages passed by underlying implementation.
Handle(msg *Message) error
// Cancel is used for cleanup during unsubscribing and it's optional.
Cancel() error
}
type SubscriberConfig struct {
ID string
Topic string
Handler MessageHandler
DeliveryPolicy DeliveryPolicy
}
// Subscriber specifies message subscription API.
type Subscriber interface {
// Subscribe subscribes to the message stream and consumes messages.
Subscribe(ctx context.Context, cfg SubscriberConfig) error
// Unsubscribe unsubscribes from the message stream and
// stops consuming messages.
Unsubscribe(ctx context.Context, id, topic string) error
// Close gracefully closes message subscriber's connection.
Close() error
}
// PubSub represents aggregation interface for publisher and subscriber.
//
//go:generate mockery --name PubSub --filename pubsub.go --quiet --note "Copyright (c) Abstract Machines"
type PubSub interface {
Publisher
Subscriber
}
// Option represents optional configuration for message broker.
//
// This is used to provide optional configuration parameters to the
// underlying publisher and pubsub implementation so that it can be
// configured to meet the specific needs.
//
// For example, it can be used to set the message prefix so that
// brokers can be used for event sourcing as well as internal message broker.
// Using value of type interface is not recommended but is the most suitable
// for this use case as options should be compiled with respect to the
// underlying broker which can either be RabbitMQ or NATS.
//
// The example below shows how to set the prefix and jetstream stream for NATS.
//
// Example:
//
// broker.NewPublisher(ctx, url, broker.Prefix(eventsPrefix), broker.JSStream(js))
type Option func(vals interface{}) error
-11
View File
@@ -1,11 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package rabbitmq holds the implementation of the Publisher and PubSub
// interfaces for the RabbitMQ messaging system, the internal messaging
// broker of the Magistrala IoT platform. Due to the practical requirements
// implementation Publisher is created alongside PubSub. The reason for
// this is that Subscriber implementation of RabbitMQ brings the burden of
// additional struct fields which are not used by Publisher. Subscriber
// is not implemented separately because PubSub can be used where Subscriber is needed.
package rabbitmq
-60
View File
@@ -1,60 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package rabbitmq
import (
"errors"
"github.com/absmach/magistrala/pkg/messaging"
amqp "github.com/rabbitmq/amqp091-go"
)
// ErrInvalidType is returned when the provided value is not of the expected type.
var ErrInvalidType = errors.New("invalid type")
// Prefix sets the prefix for the publisher.
func Prefix(prefix string) messaging.Option {
return func(val interface{}) error {
p, ok := val.(*publisher)
if !ok {
return ErrInvalidType
}
p.prefix = prefix
return nil
}
}
// Channel sets the channel for the publisher or subscriber.
func Channel(channel *amqp.Channel) messaging.Option {
return func(val interface{}) error {
switch v := val.(type) {
case *publisher:
v.channel = channel
case *pubsub:
v.channel = channel
default:
return ErrInvalidType
}
return nil
}
}
// Exchange sets the exchange for the publisher or subscriber.
func Exchange(exchange string) messaging.Option {
return func(val interface{}) error {
switch v := val.(type) {
case *publisher:
v.exchange = exchange
case *pubsub:
v.exchange = exchange
default:
return ErrInvalidType
}
return nil
}
}
-95
View File
@@ -1,95 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package rabbitmq
import (
"context"
"fmt"
"strings"
"github.com/absmach/magistrala/pkg/messaging"
amqp "github.com/rabbitmq/amqp091-go"
"google.golang.org/protobuf/proto"
)
var _ messaging.Publisher = (*publisher)(nil)
type publisher struct {
conn *amqp.Connection
channel *amqp.Channel
prefix string
exchange string
}
// NewPublisher returns RabbitMQ message Publisher.
func NewPublisher(url string, opts ...messaging.Option) (messaging.Publisher, error) {
conn, err := amqp.Dial(url)
if err != nil {
return nil, err
}
ch, err := conn.Channel()
if err != nil {
return nil, err
}
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
return nil, err
}
ret := &publisher{
conn: conn,
channel: ch,
prefix: chansPrefix,
exchange: exchangeName,
}
for _, opt := range opts {
if err := opt(ret); err != nil {
return nil, err
}
}
return ret, nil
}
func (pub *publisher) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
if topic == "" {
return ErrEmptyTopic
}
data, err := proto.Marshal(msg)
if err != nil {
return err
}
subject := fmt.Sprintf("%s.%s", pub.prefix, topic)
if msg.GetSubtopic() != "" {
subject = fmt.Sprintf("%s.%s", subject, msg.GetSubtopic())
}
subject = formatTopic(subject)
err = pub.channel.PublishWithContext(
ctx,
pub.exchange,
subject,
false,
false,
amqp.Publishing{
Headers: amqp.Table{},
ContentType: "application/octet-stream",
AppId: "magistrala-publisher",
Body: data,
})
if err != nil {
return err
}
return nil
}
func (pub *publisher) Close() error {
return pub.conn.Close()
}
func formatTopic(topic string) string {
return strings.ReplaceAll(topic, ">", "#")
}
-191
View File
@@ -1,191 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package rabbitmq
import (
"context"
"errors"
"fmt"
"log/slog"
"sync"
"github.com/absmach/magistrala/pkg/messaging"
amqp "github.com/rabbitmq/amqp091-go"
"google.golang.org/protobuf/proto"
)
const (
// SubjectAllChannels represents subject to subscribe for all the channels.
SubjectAllChannels = "channels.#"
exchangeName = "messages"
chansPrefix = "channels"
)
var (
// ErrNotSubscribed indicates that the topic is not subscribed to.
ErrNotSubscribed = errors.New("not subscribed")
// ErrEmptyTopic indicates the absence of topic.
ErrEmptyTopic = errors.New("empty topic")
// ErrEmptyID indicates the absence of ID.
ErrEmptyID = errors.New("empty ID")
)
var _ messaging.PubSub = (*pubsub)(nil)
type subscription struct {
cancel func() error
}
type pubsub struct {
publisher
logger *slog.Logger
subscriptions map[string]map[string]subscription
mu sync.Mutex
}
// NewPubSub returns RabbitMQ message publisher/subscriber.
func NewPubSub(url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) {
conn, err := amqp.Dial(url)
if err != nil {
return nil, err
}
ch, err := conn.Channel()
if err != nil {
return nil, err
}
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
return nil, err
}
ret := &pubsub{
publisher: publisher{
conn: conn,
channel: ch,
exchange: exchangeName,
prefix: chansPrefix,
},
logger: logger,
subscriptions: make(map[string]map[string]subscription),
}
for _, opt := range opts {
if err := opt(ret); err != nil {
return nil, err
}
}
return ret, nil
}
func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
if cfg.ID == "" {
return ErrEmptyID
}
if cfg.Topic == "" {
return ErrEmptyTopic
}
ps.mu.Lock()
cfg.Topic = formatTopic(cfg.Topic)
// Check topic
s, ok := ps.subscriptions[cfg.Topic]
if ok {
// Check client ID
if _, ok := s[cfg.ID]; ok {
// Unlocking, so that Unsubscribe() can access ps.subscriptions
ps.mu.Unlock()
if err := ps.Unsubscribe(ctx, cfg.ID, cfg.Topic); err != nil {
return err
}
ps.mu.Lock()
// value of s can be changed while ps.mu is unlocked
s = ps.subscriptions[cfg.Topic]
}
}
defer ps.mu.Unlock()
if s == nil {
s = make(map[string]subscription)
ps.subscriptions[cfg.Topic] = s
}
clientID := fmt.Sprintf("%s-%s", cfg.Topic, cfg.ID)
queue, err := ps.channel.QueueDeclare(clientID, true, false, false, false, nil)
if err != nil {
return err
}
if err := ps.channel.QueueBind(queue.Name, cfg.Topic, ps.exchange, false, nil); err != nil {
return err
}
msgs, err := ps.channel.Consume(queue.Name, clientID, true, false, false, false, nil)
if err != nil {
return err
}
go ps.handle(msgs, cfg.Handler)
s[cfg.ID] = subscription{
cancel: func() error {
if err := ps.channel.Cancel(clientID, false); err != nil {
return err
}
return cfg.Handler.Cancel()
},
}
return nil
}
func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error {
if id == "" {
return ErrEmptyID
}
if topic == "" {
return ErrEmptyTopic
}
ps.mu.Lock()
defer ps.mu.Unlock()
topic = formatTopic(topic)
// Check topic
s, ok := ps.subscriptions[topic]
if !ok {
return ErrNotSubscribed
}
// Check topic ID
current, ok := s[id]
if !ok {
return ErrNotSubscribed
}
if current.cancel != nil {
if err := current.cancel(); err != nil {
return err
}
}
if err := ps.channel.QueueUnbind(topic, topic, exchangeName, nil); err != nil {
return err
}
delete(s, id)
if len(s) == 0 {
delete(ps.subscriptions, topic)
}
return nil
}
func (ps *pubsub) handle(deliveries <-chan amqp.Delivery, h messaging.MessageHandler) {
for d := range deliveries {
var msg messaging.Message
if err := proto.Unmarshal(d.Body, &msg); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
return
}
if err := h.Handle(&msg); err != nil {
ps.logger.Warn(fmt.Sprintf("Failed to handle Magistrala message: %s", err))
return
}
}
}
-460
View File
@@ -1,460 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package rabbitmq_test
import (
"context"
"errors"
"fmt"
"testing"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/rabbitmq"
amqp "github.com/rabbitmq/amqp091-go"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
const (
topic = "topic"
chansPrefix = "channels"
channel = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
subtopic = "engine"
clientID = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
exchangeName = "messages"
)
var (
msgChan = make(chan *messaging.Message)
data = []byte("payload")
)
var errFailedHandleMessage = errors.New("failed to handle magistrala message")
func TestPublisher(t *testing.T) {
// Subscribing with topic, and with subtopic, so that we can publish messages.
conn, ch, err := newConn()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
topicChan := subscribe(t, ch, fmt.Sprintf("%s.%s", chansPrefix, topic))
subtopicChan := subscribe(t, ch, fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic))
go rabbitHandler(topicChan, handler{})
go rabbitHandler(subtopicChan, handler{})
t.Cleanup(func() {
conn.Close()
ch.Close()
})
cases := []struct {
desc string
channel string
subtopic string
payload []byte
}{
{
desc: "publish message with nil payload",
payload: nil,
},
{
desc: "publish message with string payload",
payload: data,
},
{
desc: "publish message with channel",
payload: data,
channel: channel,
},
{
desc: "publish message with subtopic",
payload: data,
subtopic: subtopic,
},
{
desc: "publish message with channel and subtopic",
payload: data,
channel: channel,
subtopic: subtopic,
},
}
for _, tc := range cases {
expectedMsg := messaging.Message{
Publisher: clientID,
Channel: tc.channel,
Subtopic: tc.subtopic,
Payload: tc.payload,
}
err = pubsub.Publish(context.TODO(), topic, &expectedMsg)
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error: %s", tc.desc, err))
receivedMsg := <-msgChan
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
}
}
func TestSubscribe(t *testing.T) {
// Creating rabbitmq connection and channel, so that we can publish messages.
conn, ch, err := newConn()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
t.Cleanup(func() {
conn.Close()
ch.Close()
})
cases := []struct {
desc string
topic string
clientID string
err error
handler messaging.MessageHandler
}{
{
desc: "Subscribe to a topic with an ID",
topic: topic,
clientID: "clientid1",
err: nil,
handler: handler{false, "clientid1"},
},
{
desc: "Subscribe to the same topic with a different ID",
topic: topic,
clientID: "clientid2",
err: nil,
handler: handler{false, "clientid2"},
},
{
desc: "Subscribe to an already subscribed topic with an ID",
topic: topic,
clientID: "clientid1",
err: nil,
handler: handler{false, "clientid1"},
},
{
desc: "Subscribe to a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s", topic, subtopic),
clientID: "clientid1",
err: nil,
handler: handler{false, "clientid1"},
},
{
desc: "Subscribe to an already subscribed topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s", topic, subtopic),
clientID: "clientid1",
err: nil,
handler: handler{false, "clientid1"},
},
{
desc: "Subscribe to an empty topic with an ID",
topic: "",
clientID: "clientid1",
err: rabbitmq.ErrEmptyTopic,
handler: handler{false, "clientid1"},
},
{
desc: "Subscribe to a topic with empty id",
topic: topic,
clientID: "",
err: rabbitmq.ErrEmptyID,
handler: handler{false, ""},
},
}
for _, tc := range cases {
subCfg := messaging.SubscriberConfig{
ID: tc.clientID,
Topic: tc.topic,
Handler: tc.handler,
}
err := pubsub.Subscribe(context.TODO(), subCfg)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
if tc.err == nil {
expectedMsg := messaging.Message{
Publisher: "CLIENTID",
Channel: channel,
Subtopic: subtopic,
Payload: data,
}
data, err := proto.Marshal(&expectedMsg)
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
err = ch.PublishWithContext(
context.Background(),
exchangeName,
tc.topic,
false,
false,
amqp.Publishing{
Headers: amqp.Table{},
ContentType: "application/octet-stream",
AppId: "magistrala-publisher",
Body: data,
})
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
receivedMsg := <-msgChan
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
}
}
}
func TestUnsubscribe(t *testing.T) {
// Test Subscribe and Unsubscribe
cases := []struct {
desc string
topic string
clientID string
err error
subscribe bool // True for subscribe and false for unsubscribe.
handler messaging.MessageHandler
}{
{
desc: "Subscribe to a topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid4",
err: nil,
subscribe: true,
handler: handler{false, "clientid4"},
},
{
desc: "Subscribe to the same topic with a different ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid9",
err: nil,
subscribe: true,
handler: handler{false, "clientid9"},
},
{
desc: "Unsubscribe from a topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid4",
err: nil,
subscribe: false,
handler: handler{false, "clientid4"},
},
{
desc: "Unsubscribe from same topic with different ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid9",
err: nil,
subscribe: false,
handler: handler{false, "clientid9"},
},
{
desc: "Unsubscribe from a non-existent topic with an ID",
topic: "h",
clientID: "clientid4",
err: rabbitmq.ErrNotSubscribed,
subscribe: false,
handler: handler{false, "clientid4"},
},
{
desc: "Unsubscribe from an already unsubscribed topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "clientid4",
err: rabbitmq.ErrNotSubscribed,
subscribe: false,
handler: handler{false, "clientid4"},
},
{
desc: "Subscribe to a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientidd4",
err: nil,
subscribe: true,
handler: handler{false, "clientidd4"},
},
{
desc: "Unsubscribe from a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientidd4",
err: nil,
subscribe: false,
handler: handler{false, "clientidd4"},
},
{
desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
clientID: "clientid4",
err: rabbitmq.ErrNotSubscribed,
subscribe: false,
handler: handler{false, "clientid4"},
},
{
desc: "Unsubscribe from an empty topic with an ID",
topic: "",
clientID: "clientid4",
err: rabbitmq.ErrEmptyTopic,
subscribe: false,
handler: handler{false, "clientid4"},
},
{
desc: "Unsubscribe from a topic with empty ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
clientID: "",
err: rabbitmq.ErrEmptyID,
subscribe: false,
handler: handler{false, ""},
},
{
desc: "Subscribe to a new topic with an ID",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"),
clientID: "clientid55",
err: nil,
subscribe: true,
handler: handler{true, "clientid5"},
},
{
desc: "Unsubscribe from a topic with an ID with failing handler",
topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"),
clientID: "clientid55",
err: errFailedHandleMessage,
subscribe: false,
handler: handler{true, "clientid5"},
},
{
desc: "Subscribe to a new topic with subtopic with an ID",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic),
clientID: "clientid55",
err: nil,
subscribe: true,
handler: handler{true, "clientid5"},
},
{
desc: "Unsubscribe from a topic with subtopic with an ID with failing handler",
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic),
clientID: "clientid55",
err: errFailedHandleMessage,
subscribe: false,
handler: handler{true, "clientid5"},
},
}
for _, tc := range cases {
subCfg := messaging.SubscriberConfig{
ID: tc.clientID,
Topic: tc.topic,
Handler: tc.handler,
}
switch tc.subscribe {
case true:
err := pubsub.Subscribe(context.TODO(), subCfg)
assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
default:
err := pubsub.Unsubscribe(context.TODO(), tc.clientID, tc.topic)
assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
}
}
}
func TestPubSub(t *testing.T) {
cases := []struct {
desc string
topic string
clientID string
err error
handler messaging.MessageHandler
}{
{
desc: "Subscribe to a topic with an ID",
topic: topic,
clientID: clientID,
err: nil,
handler: handler{false, clientID},
},
{
desc: "Subscribe to the same topic with a different ID",
topic: topic,
clientID: clientID + "1",
err: nil,
handler: handler{false, clientID + "1"},
},
{
desc: "Subscribe to a topic with a subtopic with an ID",
topic: fmt.Sprintf("%s.%s", topic, subtopic),
clientID: clientID + "2",
err: nil,
handler: handler{false, clientID + "2"},
},
{
desc: "Subscribe to an empty topic with an ID",
topic: "",
clientID: clientID,
err: rabbitmq.ErrEmptyTopic,
handler: handler{false, clientID},
},
{
desc: "Subscribe to a topic with empty id",
topic: topic,
clientID: "",
err: rabbitmq.ErrEmptyID,
handler: handler{false, ""},
},
}
for _, tc := range cases {
subject := ""
if tc.topic != "" {
subject = fmt.Sprintf("%s.%s", chansPrefix, tc.topic)
}
subCfg := messaging.SubscriberConfig{
ID: tc.clientID,
Topic: subject,
Handler: tc.handler,
}
err := pubsub.Subscribe(context.TODO(), subCfg)
switch tc.err {
case nil:
// If no error, publish message, and receive after subscribing.
expectedMsg := messaging.Message{
Channel: channel,
Payload: data,
}
err = pubsub.Publish(context.TODO(), tc.topic, &expectedMsg)
assert.Nil(t, err, fmt.Sprintf("%s got unexpected error: %s", tc.desc, err))
receivedMsg := <-msgChan
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
err = pubsub.Unsubscribe(context.TODO(), tc.clientID, fmt.Sprintf("%s.%s", chansPrefix, tc.topic))
assert.Nil(t, err, fmt.Sprintf("%s got unexpected error: %s", tc.desc, err))
default:
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, err, tc.err))
}
}
}
type handler struct {
fail bool
publisher string
}
func (h handler) Handle(msg *messaging.Message) error {
if msg.GetPublisher() != h.publisher {
msgChan <- msg
}
return nil
}
func (h handler) Cancel() error {
if h.fail {
return errFailedHandleMessage
}
return nil
}
-131
View File
@@ -1,131 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package rabbitmq_test
import (
"fmt"
"log"
"log/slog"
"os"
"os/signal"
"syscall"
"testing"
mglog "github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/rabbitmq"
"github.com/ory/dockertest/v3"
amqp "github.com/rabbitmq/amqp091-go"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
const (
port = "5672/tcp"
brokerName = "rabbitmq"
brokerVersion = "3.12.12-alpine"
)
var (
publisher messaging.Publisher
pubsub messaging.PubSub
logger *slog.Logger
address string
)
func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err := pool.Run(brokerName, brokerVersion, []string{})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
handleInterrupt(pool, container)
address = fmt.Sprintf("amqp://%s:%s", "localhost", container.GetPort(port))
if err := pool.Retry(func() error {
publisher, err = rabbitmq.NewPublisher(address)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
logger, err = mglog.New(os.Stdout, "debug")
if err != nil {
log.Fatal(err.Error())
}
if err := pool.Retry(func() error {
pubsub, err = rabbitmq.NewPubSub(address, logger)
return err
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
code := m.Run()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
}
func newConn() (*amqp.Connection, *amqp.Channel, error) {
conn, err := amqp.Dial(address)
if err != nil {
return nil, nil, err
}
ch, err := conn.Channel()
if err != nil {
return nil, nil, err
}
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
return nil, nil, err
}
return conn, ch, nil
}
func rabbitHandler(deliveries <-chan amqp.Delivery, h messaging.MessageHandler) {
for d := range deliveries {
var msg messaging.Message
if err := proto.Unmarshal(d.Body, &msg); err != nil {
logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
return
}
if err := h.Handle(&msg); err != nil {
logger.Warn(fmt.Sprintf("Failed to handle Magistrala message: %s", err))
return
}
}
}
func subscribe(t *testing.T, ch *amqp.Channel, topic string) <-chan amqp.Delivery {
_, err := ch.QueueDeclare(topic, true, true, true, false, nil)
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
err = ch.QueueBind(topic, topic, exchangeName, false, nil)
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
clientID := fmt.Sprintf("%s-%s", topic, clientID)
msgs, err := ch.Consume(topic, clientID, true, false, false, false, nil)
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
return msgs
}
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
c := make(chan os.Signal, 2)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(0)
}()
}
-12
View File
@@ -1,12 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package tracing provides tracing instrumentation for Magistrala things policies service.
//
// This package provides tracing middleware for Magistrala things policies service.
// It can be used to trace incoming requests and add tracing capabilities to
// Magistrala things policies service.
//
// For more details about tracing instrumentation for Magistrala messaging refer
// to the documentation at https://docs.magistrala.abstractmachines.fr/tracing/.
package tracing
@@ -1,54 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"context"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/tracing"
"github.com/absmach/magistrala/pkg/server"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// Traced operations.
const publishOP = "publish"
var defaultAttributes = []attribute.KeyValue{
attribute.String("messaging.system", "rabbitmq"),
attribute.String("network.protocol.name", "amqp"),
attribute.String("network.protocol.version", "3.9.20"),
attribute.String("messaging.rabbitmq.destination.routing_key", "magistrala"),
}
var _ messaging.Publisher = (*publisherMiddleware)(nil)
type publisherMiddleware struct {
publisher messaging.Publisher
tracer trace.Tracer
host server.Config
}
func NewPublisher(config server.Config, tracer trace.Tracer, publisher messaging.Publisher) messaging.Publisher {
pub := &publisherMiddleware{
publisher: publisher,
tracer: tracer,
host: config,
}
return pub
}
func (pm *publisherMiddleware) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
ctx, span := tracing.CreateSpan(ctx, publishOP, msg.GetPublisher(), topic, msg.GetSubtopic(), len(msg.GetPayload()), pm.host, trace.SpanKindClient, pm.tracer)
defer span.End()
span.SetAttributes(defaultAttributes...)
return pm.publisher.Publish(ctx, topic, msg)
}
func (pm *publisherMiddleware) Close() error {
return pm.publisher.Close()
}
-96
View File
@@ -1,96 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"context"
"github.com/absmach/magistrala/pkg/messaging"
"github.com/absmach/magistrala/pkg/messaging/tracing"
"github.com/absmach/magistrala/pkg/server"
"go.opentelemetry.io/otel/trace"
)
// Constants to define different operations to be traced.
const (
subscribeOP = "receive"
unsubscribeOp = "unsubscribe" // This is not specified in the open telemetry spec.
processOp = "process"
)
var _ messaging.PubSub = (*pubsubMiddleware)(nil)
type pubsubMiddleware struct {
publisherMiddleware
pubsub messaging.PubSub
host server.Config
}
// NewPubSub creates a new pubsub middleware that traces pubsub operations.
func NewPubSub(config server.Config, tracer trace.Tracer, pubsub messaging.PubSub) messaging.PubSub {
pb := &pubsubMiddleware{
publisherMiddleware: publisherMiddleware{
publisher: pubsub,
tracer: tracer,
host: config,
},
pubsub: pubsub,
host: config,
}
return pb
}
// Subscribe creates a new subscription and traces the operation.
func (pm *pubsubMiddleware) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
ctx, span := tracing.CreateSpan(ctx, subscribeOP, cfg.ID, cfg.Topic, "", 0, pm.host, trace.SpanKindClient, pm.tracer)
defer span.End()
span.SetAttributes(defaultAttributes...)
cfg.Handler = &traceHandler{
ctx: ctx,
handler: cfg.Handler,
tracer: pm.tracer,
host: pm.host,
topic: cfg.Topic,
clientID: cfg.ID,
}
return pm.pubsub.Subscribe(ctx, cfg)
}
// Unsubscribe removes an existing subscription and traces the operation.
func (pm *pubsubMiddleware) Unsubscribe(ctx context.Context, id, topic string) error {
ctx, span := tracing.CreateSpan(ctx, unsubscribeOp, id, topic, "", 0, pm.host, trace.SpanKindInternal, pm.tracer)
defer span.End()
span.SetAttributes(defaultAttributes...)
return pm.pubsub.Unsubscribe(ctx, id, topic)
}
// TraceHandler is used to trace the message handling operation.
type traceHandler struct {
ctx context.Context
handler messaging.MessageHandler
tracer trace.Tracer
host server.Config
topic string
clientID string
}
// Handle instruments the message handling operation.
func (h *traceHandler) Handle(msg *messaging.Message) error {
_, span := tracing.CreateSpan(h.ctx, processOp, h.clientID, h.topic, msg.GetSubtopic(), len(msg.GetPayload()), h.host, trace.SpanKindConsumer, h.tracer)
defer span.End()
span.SetAttributes(defaultAttributes...)
return h.handler.Handle(msg)
}
// Cancel cancels the message handling operation.
func (h *traceHandler) Cancel() error {
return h.handler.Cancel()
}
-12
View File
@@ -1,12 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package tracing provides tracing instrumentation for Magistrala things policies service.
//
// This package provides tracing middleware for Magistrala things policies service.
// It can be used to trace incoming requests and add tracing capabilities to
// Magistrala things policies service.
//
// For more details about tracing instrumentation for Magistrala messaging refer
// to the documentation at https://docs.magistrala.abstractmachines.fr/tracing/.
package tracing
-44
View File
@@ -1,44 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"context"
"fmt"
"github.com/absmach/magistrala/pkg/server"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
var defaultAttributes = []attribute.KeyValue{
attribute.Bool("messaging.destination.anonymous", false),
attribute.String("messaging.destination.template", "channels/{channelID}/messages/*"),
attribute.Bool("messaging.destination.temporary", true),
attribute.String("network.transport", "tcp"),
attribute.String("network.type", "ipv4"),
}
func CreateSpan(ctx context.Context, operation, clientID, topic, subTopic string, msgSize int, cfg server.Config, spanKind trace.SpanKind, tracer trace.Tracer) (context.Context, trace.Span) {
subject := fmt.Sprintf("channels.%s.messages", topic)
if subTopic != "" {
subject = fmt.Sprintf("%s.%s", subject, subTopic)
}
spanName := fmt.Sprintf("%s %s", subject, operation)
kvOpts := []attribute.KeyValue{
attribute.String("messaging.operation", operation),
attribute.String("messaging.client_id", clientID),
attribute.String("messaging.destination.name", subject),
attribute.String("server.address", cfg.Host),
attribute.String("server.socket.port", cfg.Port),
}
if msgSize > 0 {
kvOpts = append(kvOpts, attribute.Int("messaging.message.payload_size_bytes", msgSize))
}
kvOpts = append(kvOpts, defaultAttributes...)
return tracer.Start(ctx, spanName, trace.WithAttributes(kvOpts...), trace.WithSpanKind(spanKind))
}
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package oauth2 contains the domain concept definitions needed to support
// Magistrala ui service OAuth2 functionality.
package oauth2
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package google contains the domain concept definitions needed to support
// Magistrala services for Google OAuth2 functionality.
package google
-132
View File
@@ -1,132 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package google
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"time"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
mgoauth2 "github.com/absmach/magistrala/pkg/oauth2"
uclient "github.com/absmach/magistrala/users"
"golang.org/x/oauth2"
googleoauth2 "golang.org/x/oauth2/google"
)
const (
providerName = "google"
defTimeout = 1 * time.Minute
userInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo?access_token="
tokenInfoURL = "https://oauth2.googleapis.com/tokeninfo?access_token="
)
var scopes = []string{
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
var _ mgoauth2.Provider = (*config)(nil)
type config struct {
config *oauth2.Config
state string
uiRedirectURL string
errorURL string
}
// NewProvider returns a new Google OAuth provider.
func NewProvider(cfg mgoauth2.Config, uiRedirectURL, errorURL string) mgoauth2.Provider {
return &config{
config: &oauth2.Config{
ClientID: cfg.ClientID,
ClientSecret: cfg.ClientSecret,
Endpoint: googleoauth2.Endpoint,
RedirectURL: cfg.RedirectURL,
Scopes: scopes,
},
state: cfg.State,
uiRedirectURL: uiRedirectURL,
errorURL: errorURL,
}
}
func (cfg *config) Name() string {
return providerName
}
func (cfg *config) State() string {
return cfg.state
}
func (cfg *config) RedirectURL() string {
return cfg.uiRedirectURL
}
func (cfg *config) ErrorURL() string {
return cfg.errorURL
}
func (cfg *config) IsEnabled() bool {
return cfg.config.ClientID != "" && cfg.config.ClientSecret != ""
}
func (cfg *config) Exchange(ctx context.Context, code string) (oauth2.Token, error) {
token, err := cfg.config.Exchange(ctx, code)
if err != nil {
return oauth2.Token{}, err
}
return *token, nil
}
func (cfg *config) UserInfo(accessToken string) (uclient.User, error) {
resp, err := http.Get(userInfoURL + url.QueryEscape(accessToken))
if err != nil {
return uclient.User{}, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return uclient.User{}, svcerr.ErrAuthentication
}
data, err := io.ReadAll(resp.Body)
if err != nil {
return uclient.User{}, err
}
var user struct {
ID string `json:"id"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
Username string `json:"username"`
Email string `json:"email"`
Picture string `json:"picture"`
}
if err := json.Unmarshal(data, &user); err != nil {
return uclient.User{}, err
}
if user.ID == "" || user.FirstName == "" || user.LastName == "" || user.Email == "" {
return uclient.User{}, svcerr.ErrAuthentication
}
client := uclient.User{
ID: user.ID,
FirstName: user.FirstName,
LastName: user.LastName,
Email: user.Email,
Metadata: map[string]interface{}{
"oauth_provider": providerName,
"profile_picture": user.Picture,
},
Status: uclient.EnabledStatus,
}
return client, nil
}
-180
View File
@@ -1,180 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
mock "github.com/stretchr/testify/mock"
users "github.com/absmach/magistrala/users"
xoauth2 "golang.org/x/oauth2"
)
// Provider is an autogenerated mock type for the Provider type
type Provider struct {
mock.Mock
}
// ErrorURL provides a mock function with given fields:
func (_m *Provider) ErrorURL() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for ErrorURL")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// Exchange provides a mock function with given fields: ctx, code
func (_m *Provider) Exchange(ctx context.Context, code string) (xoauth2.Token, error) {
ret := _m.Called(ctx, code)
if len(ret) == 0 {
panic("no return value specified for Exchange")
}
var r0 xoauth2.Token
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (xoauth2.Token, error)); ok {
return rf(ctx, code)
}
if rf, ok := ret.Get(0).(func(context.Context, string) xoauth2.Token); ok {
r0 = rf(ctx, code)
} else {
r0 = ret.Get(0).(xoauth2.Token)
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, code)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// IsEnabled provides a mock function with given fields:
func (_m *Provider) IsEnabled() bool {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for IsEnabled")
}
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// Name provides a mock function with given fields:
func (_m *Provider) Name() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Name")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// RedirectURL provides a mock function with given fields:
func (_m *Provider) RedirectURL() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for RedirectURL")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// State provides a mock function with given fields:
func (_m *Provider) State() string {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for State")
}
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// UserInfo provides a mock function with given fields: accessToken
func (_m *Provider) UserInfo(accessToken string) (users.User, error) {
ret := _m.Called(accessToken)
if len(ret) == 0 {
panic("no return value specified for UserInfo")
}
var r0 users.User
var r1 error
if rf, ok := ret.Get(0).(func(string) (users.User, error)); ok {
return rf(accessToken)
}
if rf, ok := ret.Get(0).(func(string) users.User); ok {
r0 = rf(accessToken)
} else {
r0 = ret.Get(0).(users.User)
}
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(accessToken)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewProvider creates a new instance of Provider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewProvider(t interface {
mock.TestingT
Cleanup(func())
}) *Provider {
mock := &Provider{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-46
View File
@@ -1,46 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"context"
"github.com/absmach/magistrala/users"
"golang.org/x/oauth2"
)
// Config is the configuration for the OAuth2 provider.
type Config struct {
ClientID string `env:"CLIENT_ID" envDefault:""`
ClientSecret string `env:"CLIENT_SECRET" envDefault:""`
State string `env:"STATE" envDefault:""`
RedirectURL string `env:"REDIRECT_URL" envDefault:""`
}
// Provider is an interface that provides the OAuth2 flow for a specific provider
// (e.g. Google, GitHub, etc.)
//
//go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Abstract Machines"
type Provider interface {
// Name returns the name of the OAuth2 provider.
Name() string
// State returns the current state for the OAuth2 flow.
State() string
// RedirectURL returns the URL to redirect the user to after completing the OAuth2 flow.
RedirectURL() string
// ErrorURL returns the URL to redirect the user to in case of an error during the OAuth2 flow.
ErrorURL() string
// IsEnabled checks if the OAuth2 provider is enabled.
IsEnabled() bool
// Exchange converts an authorization code into a token.
Exchange(ctx context.Context, code string) (oauth2.Token, error)
// UserInfo retrieves the user's information using the access token.
UserInfo(accessToken string) (users.User, error)
}
-5
View File
@@ -1,5 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package policies contains Magistrala policy definitions.
package policies
-64
View File
@@ -1,64 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package policies
import (
"context"
)
const (
TokenKind = "token"
GroupsKind = "groups"
NewGroupKind = "new_group"
ChannelsKind = "channels"
NewChannelKind = "new_channel"
ThingsKind = "things"
NewThingKind = "new_thing"
UsersKind = "users"
DomainsKind = "domains"
PlatformKind = "platform"
)
const (
GroupType = "group"
ThingType = "thing"
UserType = "user"
DomainType = "domain"
PlatformType = "platform"
)
const (
AdministratorRelation = "administrator"
EditorRelation = "editor"
ContributorRelation = "contributor"
MemberRelation = "member"
DomainRelation = "domain"
ParentGroupRelation = "parent_group"
RoleGroupRelation = "role_group"
GroupRelation = "group"
PlatformRelation = "platform"
GuestRelation = "guest"
)
const (
AdminPermission = "admin"
DeletePermission = "delete"
EditPermission = "edit"
ViewPermission = "view"
MembershipPermission = "membership"
SharePermission = "share"
PublishPermission = "publish"
SubscribePermission = "subscribe"
CreatePermission = "create"
)
const MagistralaObject = "magistrala"
//go:generate mockery --name Evaluator --output=./mocks --filename evaluator.go --quiet --note "Copyright (c) Abstract Machines"
type Evaluator interface {
// CheckPolicy checks if the subject has a relation on the object.
// It returns a non-nil error if the subject has no relation on
// the object (which simply means the operation is denied).
CheckPolicy(ctx context.Context, pr Policy) error
}
-49
View File
@@ -1,49 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
policies "github.com/absmach/magistrala/pkg/policies"
mock "github.com/stretchr/testify/mock"
)
// Evaluator is an autogenerated mock type for the Evaluator type
type Evaluator struct {
mock.Mock
}
// CheckPolicy provides a mock function with given fields: ctx, pr
func (_m *Evaluator) CheckPolicy(ctx context.Context, pr policies.Policy) error {
ret := _m.Called(ctx, pr)
if len(ret) == 0 {
panic("no return value specified for CheckPolicy")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) error); ok {
r0 = rf(ctx, pr)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewEvaluator creates a new instance of Evaluator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewEvaluator(t interface {
mock.TestingT
Cleanup(func())
}) *Evaluator {
mock := &Evaluator{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-301
View File
@@ -1,301 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
policies "github.com/absmach/magistrala/pkg/policies"
mock "github.com/stretchr/testify/mock"
)
// Service is an autogenerated mock type for the Service type
type Service struct {
mock.Mock
}
// AddPolicies provides a mock function with given fields: ctx, prs
func (_m *Service) AddPolicies(ctx context.Context, prs []policies.Policy) error {
ret := _m.Called(ctx, prs)
if len(ret) == 0 {
panic("no return value specified for AddPolicies")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, []policies.Policy) error); ok {
r0 = rf(ctx, prs)
} else {
r0 = ret.Error(0)
}
return r0
}
// AddPolicy provides a mock function with given fields: ctx, pr
func (_m *Service) AddPolicy(ctx context.Context, pr policies.Policy) error {
ret := _m.Called(ctx, pr)
if len(ret) == 0 {
panic("no return value specified for AddPolicy")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) error); ok {
r0 = rf(ctx, pr)
} else {
r0 = ret.Error(0)
}
return r0
}
// CountObjects provides a mock function with given fields: ctx, pr
func (_m *Service) CountObjects(ctx context.Context, pr policies.Policy) (uint64, error) {
ret := _m.Called(ctx, pr)
if len(ret) == 0 {
panic("no return value specified for CountObjects")
}
var r0 uint64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) (uint64, error)); ok {
return rf(ctx, pr)
}
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) uint64); ok {
r0 = rf(ctx, pr)
} else {
r0 = ret.Get(0).(uint64)
}
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy) error); ok {
r1 = rf(ctx, pr)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CountSubjects provides a mock function with given fields: ctx, pr
func (_m *Service) CountSubjects(ctx context.Context, pr policies.Policy) (uint64, error) {
ret := _m.Called(ctx, pr)
if len(ret) == 0 {
panic("no return value specified for CountSubjects")
}
var r0 uint64
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) (uint64, error)); ok {
return rf(ctx, pr)
}
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) uint64); ok {
r0 = rf(ctx, pr)
} else {
r0 = ret.Get(0).(uint64)
}
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy) error); ok {
r1 = rf(ctx, pr)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// DeletePolicies provides a mock function with given fields: ctx, prs
func (_m *Service) DeletePolicies(ctx context.Context, prs []policies.Policy) error {
ret := _m.Called(ctx, prs)
if len(ret) == 0 {
panic("no return value specified for DeletePolicies")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, []policies.Policy) error); ok {
r0 = rf(ctx, prs)
} else {
r0 = ret.Error(0)
}
return r0
}
// DeletePolicyFilter provides a mock function with given fields: ctx, pr
func (_m *Service) DeletePolicyFilter(ctx context.Context, pr policies.Policy) error {
ret := _m.Called(ctx, pr)
if len(ret) == 0 {
panic("no return value specified for DeletePolicyFilter")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) error); ok {
r0 = rf(ctx, pr)
} else {
r0 = ret.Error(0)
}
return r0
}
// ListAllObjects provides a mock function with given fields: ctx, pr
func (_m *Service) ListAllObjects(ctx context.Context, pr policies.Policy) (policies.PolicyPage, error) {
ret := _m.Called(ctx, pr)
if len(ret) == 0 {
panic("no return value specified for ListAllObjects")
}
var r0 policies.PolicyPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) (policies.PolicyPage, error)); ok {
return rf(ctx, pr)
}
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) policies.PolicyPage); ok {
r0 = rf(ctx, pr)
} else {
r0 = ret.Get(0).(policies.PolicyPage)
}
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy) error); ok {
r1 = rf(ctx, pr)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListAllSubjects provides a mock function with given fields: ctx, pr
func (_m *Service) ListAllSubjects(ctx context.Context, pr policies.Policy) (policies.PolicyPage, error) {
ret := _m.Called(ctx, pr)
if len(ret) == 0 {
panic("no return value specified for ListAllSubjects")
}
var r0 policies.PolicyPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) (policies.PolicyPage, error)); ok {
return rf(ctx, pr)
}
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) policies.PolicyPage); ok {
r0 = rf(ctx, pr)
} else {
r0 = ret.Get(0).(policies.PolicyPage)
}
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy) error); ok {
r1 = rf(ctx, pr)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListObjects provides a mock function with given fields: ctx, pr, nextPageToken, limit
func (_m *Service) ListObjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) (policies.PolicyPage, error) {
ret := _m.Called(ctx, pr, nextPageToken, limit)
if len(ret) == 0 {
panic("no return value specified for ListObjects")
}
var r0 policies.PolicyPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, string, uint64) (policies.PolicyPage, error)); ok {
return rf(ctx, pr, nextPageToken, limit)
}
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, string, uint64) policies.PolicyPage); ok {
r0 = rf(ctx, pr, nextPageToken, limit)
} else {
r0 = ret.Get(0).(policies.PolicyPage)
}
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy, string, uint64) error); ok {
r1 = rf(ctx, pr, nextPageToken, limit)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListPermissions provides a mock function with given fields: ctx, pr, permissionsFilter
func (_m *Service) ListPermissions(ctx context.Context, pr policies.Policy, permissionsFilter []string) (policies.Permissions, error) {
ret := _m.Called(ctx, pr, permissionsFilter)
if len(ret) == 0 {
panic("no return value specified for ListPermissions")
}
var r0 policies.Permissions
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, []string) (policies.Permissions, error)); ok {
return rf(ctx, pr, permissionsFilter)
}
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, []string) policies.Permissions); ok {
r0 = rf(ctx, pr, permissionsFilter)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(policies.Permissions)
}
}
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy, []string) error); ok {
r1 = rf(ctx, pr, permissionsFilter)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ListSubjects provides a mock function with given fields: ctx, pr, nextPageToken, limit
func (_m *Service) ListSubjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) (policies.PolicyPage, error) {
ret := _m.Called(ctx, pr, nextPageToken, limit)
if len(ret) == 0 {
panic("no return value specified for ListSubjects")
}
var r0 policies.PolicyPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, string, uint64) (policies.PolicyPage, error)); ok {
return rf(ctx, pr, nextPageToken, limit)
}
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, string, uint64) policies.PolicyPage); ok {
r0 = rf(ctx, pr, nextPageToken, limit)
} else {
r0 = ret.Get(0).(policies.PolicyPage)
}
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy, string, uint64) error); ok {
r1 = rf(ctx, pr, nextPageToken, limit)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
mock.TestingT
Cleanup(func())
}) *Service {
mock := &Service{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-104
View File
@@ -1,104 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package policies
import (
"context"
"encoding/json"
)
type Policy struct {
// Domain contains the domain ID.
Domain string `json:"domain,omitempty"`
// Subject contains the subject ID or Token.
Subject string `json:"subject"`
// SubjectType contains the subject type. Supported subject types are
// platform, group, domain, thing, users.
SubjectType string `json:"subject_type"`
// SubjectKind contains the subject kind. Supported subject kinds are
// token, users, platform, things, channels, groups, domain.
SubjectKind string `json:"subject_kind"`
// SubjectRelation contains subject relations.
SubjectRelation string `json:"subject_relation,omitempty"`
// Object contains the object ID.
Object string `json:"object"`
// ObjectKind contains the object kind. Supported object kinds are
// users, platform, things, channels, groups, domain.
ObjectKind string `json:"object_kind"`
// ObjectType contains the object type. Supported object types are
// platform, group, domain, thing, users.
ObjectType string `json:"object_type"`
// Relation contains the relation. Supported relations are administrator, editor, contributor, member, guest, parent_group,group,domain.
Relation string `json:"relation,omitempty"`
// Permission contains the permission. Supported permissions are admin, delete, edit, share, view,
// membership, create, admin_only, edit_only, view_only, membership_only, ext_admin, ext_edit, ext_view.
Permission string `json:"permission,omitempty"`
}
func (pr Policy) String() string {
data, err := json.Marshal(pr)
if err != nil {
return ""
}
return string(data)
}
type PolicyPage struct {
Policies []string
NextPageToken string
}
type Permissions []string
// PolicyService facilitates the communication to authorization
// services and implements Authz functionalities for spicedb
//
//go:generate mockery --name Service --filename service.go --quiet --note "Copyright (c) Abstract Machines"
type Service interface {
// AddPolicy creates a policy for the given subject, so that, after
// AddPolicy, `subject` has a `relation` on `object`. Returns a non-nil
// error in case of failures.
AddPolicy(ctx context.Context, pr Policy) error
// AddPolicies adds new policies for given subjects. This method is
// only allowed to use as an admin.
AddPolicies(ctx context.Context, prs []Policy) error
// DeletePolicyFilter removes policy for given policy filter request.
DeletePolicyFilter(ctx context.Context, pr Policy) error
// DeletePolicies deletes policies for given subjects. This method is
// only allowed to use as an admin.
DeletePolicies(ctx context.Context, prs []Policy) error
// ListObjects lists policies based on the given Policy structure.
ListObjects(ctx context.Context, pr Policy, nextPageToken string, limit uint64) (PolicyPage, error)
// ListAllObjects lists all policies based on the given Policy structure.
ListAllObjects(ctx context.Context, pr Policy) (PolicyPage, error)
// CountObjects count policies based on the given Policy structure.
CountObjects(ctx context.Context, pr Policy) (uint64, error)
// ListSubjects lists subjects based on the given Policy structure.
ListSubjects(ctx context.Context, pr Policy, nextPageToken string, limit uint64) (PolicyPage, error)
// ListAllSubjects lists all subjects based on the given Policy structure.
ListAllSubjects(ctx context.Context, pr Policy) (PolicyPage, error)
// CountSubjects count policies based on the given Policy structure.
CountSubjects(ctx context.Context, pr Policy) (uint64, error)
// ListPermissions lists permission betweeen given subject and object .
ListPermissions(ctx context.Context, pr Policy, permissionsFilter []string) (Permissions, error)
}
-5
View File
@@ -1,5 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package server contains the HTTP, gRPC and CoAP server implementation.
package spicedb
-64
View File
@@ -1,64 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package spicedb
import (
"context"
"log/slog"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/policies"
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/authzed-go/v1"
)
type policyEvaluator struct {
client *authzed.ClientWithExperimental
permissionClient v1.PermissionsServiceClient
logger *slog.Logger
}
func NewPolicyEvaluator(client *authzed.ClientWithExperimental, logger *slog.Logger) policies.Evaluator {
return &policyEvaluator{
client: client,
permissionClient: client.PermissionsServiceClient,
logger: logger,
}
}
func (pe *policyEvaluator) CheckPolicy(ctx context.Context, pr policies.Policy) error {
checkReq := v1.CheckPermissionRequest{
// FullyConsistent means little caching will be available, which means performance will suffer.
// Only use if a ZedToken is not available or absolutely latest information is required.
// If we want to avoid FullyConsistent and to improve the performance of spicedb, then we need to cache the ZEDTOKEN whenever RELATIONS is created or updated.
// Instead of using FullyConsistent we need to use Consistency_AtLeastAsFresh, code looks like below one.
// Consistency: &v1.Consistency{
// Requirement: &v1.Consistency_AtLeastAsFresh{
// AtLeastAsFresh: getRelationTupleZedTokenFromCache() ,
// }
// },
// Reference: https://authzed.com/docs/reference/api-consistency
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{
FullyConsistent: true,
},
},
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
Permission: pr.Permission,
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
}
resp, err := pe.permissionClient.CheckPermission(ctx, &checkReq)
if err != nil {
return handleSpicedbError(err)
}
if resp.Permissionship == v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION {
return nil
}
if reason, ok := v1.CheckPermissionResponse_Permissionship_name[int32(resp.Permissionship)]; ok {
return errors.Wrap(svcerr.ErrAuthorization, errors.New(reason))
}
return svcerr.ErrAuthorization
}
-950
View File
@@ -1,950 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package spicedb
import (
"context"
"fmt"
"io"
"log/slog"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/policies"
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
"github.com/authzed/authzed-go/v1"
gstatus "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const defRetrieveAllLimit = 1000
var (
errInvalidSubject = errors.New("invalid subject kind")
errAddPolicies = errors.New("failed to add policies")
errRetrievePolicies = errors.New("failed to retrieve policies")
errRemovePolicies = errors.New("failed to remove the policies")
errNoPolicies = errors.New("no policies provided")
errInternal = errors.New("spicedb internal error")
errPlatform = errors.New("invalid platform id")
)
var (
defThingsFilterPermissions = []string{
policies.AdminPermission,
policies.DeletePermission,
policies.EditPermission,
policies.ViewPermission,
policies.SharePermission,
policies.PublishPermission,
policies.SubscribePermission,
}
defGroupsFilterPermissions = []string{
policies.AdminPermission,
policies.DeletePermission,
policies.EditPermission,
policies.ViewPermission,
policies.MembershipPermission,
policies.SharePermission,
}
defDomainsFilterPermissions = []string{
policies.AdminPermission,
policies.EditPermission,
policies.ViewPermission,
policies.MembershipPermission,
policies.SharePermission,
}
defPlatformFilterPermissions = []string{
policies.AdminPermission,
policies.MembershipPermission,
}
)
type policyService struct {
client *authzed.ClientWithExperimental
permissionClient v1.PermissionsServiceClient
logger *slog.Logger
}
func NewPolicyService(client *authzed.ClientWithExperimental, logger *slog.Logger) policies.Service {
return &policyService{
client: client,
permissionClient: client.PermissionsServiceClient,
logger: logger,
}
}
func (ps *policyService) AddPolicy(ctx context.Context, pr policies.Policy) error {
if err := ps.policyValidation(pr); err != nil {
return errors.Wrap(svcerr.ErrInvalidPolicy, err)
}
precond, err := ps.addPolicyPreCondition(ctx, pr)
if err != nil {
return err
}
updates := []*v1.RelationshipUpdate{
{
Operation: v1.RelationshipUpdate_OPERATION_CREATE,
Relationship: &v1.Relationship{
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
Relation: pr.Relation,
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
},
},
}
_, err = ps.permissionClient.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{Updates: updates, OptionalPreconditions: precond})
if err != nil {
return errors.Wrap(errAddPolicies, handleSpicedbError(err))
}
return nil
}
func (ps *policyService) AddPolicies(ctx context.Context, prs []policies.Policy) error {
updates := []*v1.RelationshipUpdate{}
var preconds []*v1.Precondition
for _, pr := range prs {
if err := ps.policyValidation(pr); err != nil {
return errors.Wrap(svcerr.ErrInvalidPolicy, err)
}
precond, err := ps.addPolicyPreCondition(ctx, pr)
if err != nil {
return err
}
preconds = append(preconds, precond...)
updates = append(updates, &v1.RelationshipUpdate{
Operation: v1.RelationshipUpdate_OPERATION_CREATE,
Relationship: &v1.Relationship{
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
Relation: pr.Relation,
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
},
})
}
if len(updates) == 0 {
return errors.Wrap(errors.ErrMalformedEntity, errNoPolicies)
}
_, err := ps.permissionClient.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{Updates: updates, OptionalPreconditions: preconds})
if err != nil {
return errors.Wrap(errAddPolicies, handleSpicedbError(err))
}
return nil
}
func (ps *policyService) DeletePolicyFilter(ctx context.Context, pr policies.Policy) error {
req := &v1.DeleteRelationshipsRequest{
RelationshipFilter: &v1.RelationshipFilter{
ResourceType: pr.ObjectType,
OptionalResourceId: pr.Object,
},
}
if pr.Relation != "" {
req.RelationshipFilter.OptionalRelation = pr.Relation
}
if pr.SubjectType != "" {
req.RelationshipFilter.OptionalSubjectFilter = &v1.SubjectFilter{
SubjectType: pr.SubjectType,
}
if pr.Subject != "" {
req.RelationshipFilter.OptionalSubjectFilter.OptionalSubjectId = pr.Subject
}
if pr.SubjectRelation != "" {
req.RelationshipFilter.OptionalSubjectFilter.OptionalRelation = &v1.SubjectFilter_RelationFilter{
Relation: pr.SubjectRelation,
}
}
}
if _, err := ps.permissionClient.DeleteRelationships(ctx, req); err != nil {
return errors.Wrap(errRemovePolicies, handleSpicedbError(err))
}
return nil
}
func (ps *policyService) DeletePolicies(ctx context.Context, prs []policies.Policy) error {
updates := []*v1.RelationshipUpdate{}
for _, pr := range prs {
if err := ps.policyValidation(pr); err != nil {
return errors.Wrap(svcerr.ErrInvalidPolicy, err)
}
updates = append(updates, &v1.RelationshipUpdate{
Operation: v1.RelationshipUpdate_OPERATION_DELETE,
Relationship: &v1.Relationship{
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
Relation: pr.Relation,
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
},
})
}
if len(updates) == 0 {
return errors.Wrap(errors.ErrMalformedEntity, errNoPolicies)
}
_, err := ps.permissionClient.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{Updates: updates})
if err != nil {
return errors.Wrap(errRemovePolicies, handleSpicedbError(err))
}
return nil
}
func (ps *policyService) ListObjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) (policies.PolicyPage, error) {
if limit <= 0 {
limit = 100
}
res, npt, err := ps.retrieveObjects(ctx, pr, nextPageToken, limit)
if err != nil {
return policies.PolicyPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
var page policies.PolicyPage
for _, tuple := range res {
page.Policies = append(page.Policies, tuple.Object)
}
page.NextPageToken = npt
return page, nil
}
func (ps *policyService) ListAllObjects(ctx context.Context, pr policies.Policy) (policies.PolicyPage, error) {
res, err := ps.retrieveAllObjects(ctx, pr)
if err != nil {
return policies.PolicyPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
var page policies.PolicyPage
for _, tuple := range res {
page.Policies = append(page.Policies, tuple.Object)
}
return page, nil
}
func (ps *policyService) CountObjects(ctx context.Context, pr policies.Policy) (uint64, error) {
var count uint64
nextPageToken := ""
for {
relationTuples, npt, err := ps.retrieveObjects(ctx, pr, nextPageToken, defRetrieveAllLimit)
if err != nil {
return count, err
}
count = count + uint64(len(relationTuples))
if npt == "" {
break
}
nextPageToken = npt
}
return count, nil
}
func (ps *policyService) ListSubjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) (policies.PolicyPage, error) {
if limit <= 0 {
limit = 100
}
res, npt, err := ps.retrieveSubjects(ctx, pr, nextPageToken, limit)
if err != nil {
return policies.PolicyPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
var page policies.PolicyPage
for _, tuple := range res {
page.Policies = append(page.Policies, tuple.Subject)
}
page.NextPageToken = npt
return page, nil
}
func (ps *policyService) ListAllSubjects(ctx context.Context, pr policies.Policy) (policies.PolicyPage, error) {
res, err := ps.retrieveAllSubjects(ctx, pr)
if err != nil {
return policies.PolicyPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
var page policies.PolicyPage
for _, tuple := range res {
page.Policies = append(page.Policies, tuple.Subject)
}
return page, nil
}
func (ps *policyService) CountSubjects(ctx context.Context, pr policies.Policy) (uint64, error) {
var count uint64
nextPageToken := ""
for {
relationTuples, npt, err := ps.retrieveSubjects(ctx, pr, nextPageToken, defRetrieveAllLimit)
if err != nil {
return count, err
}
count = count + uint64(len(relationTuples))
if npt == "" {
break
}
nextPageToken = npt
}
return count, nil
}
func (ps *policyService) ListPermissions(ctx context.Context, pr policies.Policy, permissionsFilter []string) (policies.Permissions, error) {
if len(permissionsFilter) == 0 {
switch pr.ObjectType {
case policies.ThingType:
permissionsFilter = defThingsFilterPermissions
case policies.GroupType:
permissionsFilter = defGroupsFilterPermissions
case policies.PlatformType:
permissionsFilter = defPlatformFilterPermissions
case policies.DomainType:
permissionsFilter = defDomainsFilterPermissions
default:
return nil, svcerr.ErrMalformedEntity
}
}
pers, err := ps.retrievePermissions(ctx, pr, permissionsFilter)
if err != nil {
return []string{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return pers, nil
}
func (ps *policyService) policyValidation(pr policies.Policy) error {
if pr.ObjectType == policies.PlatformType && pr.Object != policies.MagistralaObject {
return errPlatform
}
return nil
}
func (ps *policyService) addPolicyPreCondition(ctx context.Context, pr policies.Policy) ([]*v1.Precondition, error) {
// Checks are required for following ( -> means adding)
// 1.) user -> group (both user groups and channels)
// 2.) user -> thing
// 3.) group -> group (both for adding parent_group and channels)
// 4.) group (channel) -> thing
// 5.) user -> domain
switch {
// 1.) user -> group (both user groups and channels)
// Checks :
// - USER with ANY RELATION to DOMAIN
// - GROUP with DOMAIN RELATION to DOMAIN
case pr.SubjectType == policies.UserType && pr.ObjectType == policies.GroupType:
return ps.userGroupPreConditions(ctx, pr)
// 2.) user -> thing
// Checks :
// - USER with ANY RELATION to DOMAIN
// - THING with DOMAIN RELATION to DOMAIN
case pr.SubjectType == policies.UserType && pr.ObjectType == policies.ThingType:
return ps.userThingPreConditions(ctx, pr)
// 3.) group -> group (both for adding parent_group and channels)
// Checks :
// - CHILD_GROUP with out PARENT_GROUP RELATION with any GROUP
case pr.SubjectType == policies.GroupType && pr.ObjectType == policies.GroupType:
return groupPreConditions(pr)
// 4.) group (channel) -> thing
// Checks :
// - GROUP (channel) with DOMAIN RELATION to DOMAIN
// - NO GROUP should not have PARENT_GROUP RELATION with GROUP (channel)
// - THING with DOMAIN RELATION to DOMAIN
case pr.SubjectType == policies.GroupType && pr.ObjectType == policies.ThingType:
return channelThingPreCondition(pr)
// 5.) user -> domain
// Checks :
// - User doesn't have any relation with domain
case pr.SubjectType == policies.UserType && pr.ObjectType == policies.DomainType:
return ps.userDomainPreConditions(ctx, pr)
// Check thing and group not belongs to other domain before adding to domain
case pr.SubjectType == policies.DomainType && pr.Relation == policies.DomainRelation && (pr.ObjectType == policies.ThingType || pr.ObjectType == policies.GroupType):
preconds := []*v1.Precondition{
{
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: pr.ObjectType,
OptionalResourceId: pr.Object,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
},
},
},
}
return preconds, nil
}
return nil, nil
}
func (ps *policyService) userGroupPreConditions(ctx context.Context, pr policies.Policy) ([]*v1.Precondition, error) {
var preconds []*v1.Precondition
// user should not have any relation with group
preconds = append(preconds, &v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.GroupType,
OptionalResourceId: pr.Object,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.UserType,
OptionalSubjectId: pr.Subject,
},
},
})
isSuperAdmin := false
if err := ps.checkPolicy(ctx, policies.Policy{
Subject: pr.Subject,
SubjectType: pr.SubjectType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}); err == nil {
isSuperAdmin = true
}
if !isSuperAdmin {
preconds = append(preconds, &v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.DomainType,
OptionalResourceId: pr.Domain,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.UserType,
OptionalSubjectId: pr.Subject,
},
},
})
}
switch {
case pr.ObjectKind == policies.NewGroupKind || pr.ObjectKind == policies.NewChannelKind:
preconds = append(preconds,
&v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.GroupType,
OptionalResourceId: pr.Object,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
},
},
},
)
default:
preconds = append(preconds,
&v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.GroupType,
OptionalResourceId: pr.Object,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
OptionalSubjectId: pr.Domain,
},
},
},
)
}
return preconds, nil
}
func (ps *policyService) userThingPreConditions(ctx context.Context, pr policies.Policy) ([]*v1.Precondition, error) {
var preconds []*v1.Precondition
// user should not have any relation with thing
preconds = append(preconds, &v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.ThingType,
OptionalResourceId: pr.Object,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.UserType,
OptionalSubjectId: pr.Subject,
},
},
})
isSuperAdmin := false
if err := ps.checkPolicy(ctx, policies.Policy{
Subject: pr.Subject,
SubjectType: pr.SubjectType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}); err == nil {
isSuperAdmin = true
}
if !isSuperAdmin {
preconds = append(preconds, &v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.DomainType,
OptionalResourceId: pr.Domain,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.UserType,
OptionalSubjectId: pr.Subject,
},
},
})
}
switch {
// For New thing
// - THING without DOMAIN RELATION to ANY DOMAIN
case pr.ObjectKind == policies.NewThingKind:
preconds = append(preconds,
&v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.ThingType,
OptionalResourceId: pr.Object,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
},
},
},
)
default:
// For existing thing
// - THING without DOMAIN RELATION to ANY DOMAIN
preconds = append(preconds,
&v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.ThingType,
OptionalResourceId: pr.Object,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
OptionalSubjectId: pr.Domain,
},
},
},
)
}
return preconds, nil
}
func (ps *policyService) userDomainPreConditions(ctx context.Context, pr policies.Policy) ([]*v1.Precondition, error) {
var preconds []*v1.Precondition
if err := ps.checkPolicy(ctx, policies.Policy{
Subject: pr.Subject,
SubjectType: pr.SubjectType,
Permission: policies.AdminPermission,
Object: policies.MagistralaObject,
ObjectType: policies.PlatformType,
}); err == nil {
return preconds, fmt.Errorf("use already exists in domain")
}
// user should not have any relation with domain.
preconds = append(preconds, &v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.DomainType,
OptionalResourceId: pr.Object,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.UserType,
OptionalSubjectId: pr.Subject,
},
},
})
return preconds, nil
}
func (ps *policyService) checkPolicy(ctx context.Context, pr policies.Policy) error {
checkReq := v1.CheckPermissionRequest{
// FullyConsistent means little caching will be available, which means performance will suffer.
// Only use if a ZedToken is not available or absolutely latest information is required.
// If we want to avoid FullyConsistent and to improve the performance of spicedb, then we need to cache the ZEDTOKEN whenever RELATIONS is created or updated.
// Instead of using FullyConsistent we need to use Consistency_AtLeastAsFresh, code looks like below one.
// Consistency: &v1.Consistency{
// Requirement: &v1.Consistency_AtLeastAsFresh{
// AtLeastAsFresh: getRelationTupleZedTokenFromCache() ,
// }
// },
// Reference: https://authzed.com/docs/reference/api-consistency
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{
FullyConsistent: true,
},
},
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
Permission: pr.Permission,
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
}
resp, err := ps.permissionClient.CheckPermission(ctx, &checkReq)
if err != nil {
return handleSpicedbError(err)
}
if resp.Permissionship == v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION {
return nil
}
if reason, ok := v1.CheckPermissionResponse_Permissionship_name[int32(resp.Permissionship)]; ok {
return errors.Wrap(svcerr.ErrAuthorization, errors.New(reason))
}
return svcerr.ErrAuthorization
}
func (ps *policyService) retrieveObjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) ([]policies.Policy, string, error) {
resourceReq := &v1.LookupResourcesRequest{
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{
FullyConsistent: true,
},
},
ResourceObjectType: pr.ObjectType,
Permission: pr.Permission,
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
OptionalLimit: uint32(limit),
}
if nextPageToken != "" {
resourceReq.OptionalCursor = &v1.Cursor{Token: nextPageToken}
}
stream, err := ps.permissionClient.LookupResources(ctx, resourceReq)
if err != nil {
return nil, "", errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
}
resources := []*v1.LookupResourcesResponse{}
var token string
for {
resp, err := stream.Recv()
switch err {
case nil:
resources = append(resources, resp)
case io.EOF:
if len(resources) > 0 && resources[len(resources)-1].AfterResultCursor != nil {
token = resources[len(resources)-1].AfterResultCursor.Token
}
return objectsToAuthPolicies(resources), token, nil
default:
if len(resources) > 0 && resources[len(resources)-1].AfterResultCursor != nil {
token = resources[len(resources)-1].AfterResultCursor.Token
}
return []policies.Policy{}, token, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
}
}
}
func (ps *policyService) retrieveAllObjects(ctx context.Context, pr policies.Policy) ([]policies.Policy, error) {
resourceReq := &v1.LookupResourcesRequest{
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{
FullyConsistent: true,
},
},
ResourceObjectType: pr.ObjectType,
Permission: pr.Permission,
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
}
stream, err := ps.permissionClient.LookupResources(ctx, resourceReq)
if err != nil {
return nil, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
}
tuples := []policies.Policy{}
for {
resp, err := stream.Recv()
switch {
case errors.Contains(err, io.EOF):
return tuples, nil
case err != nil:
return tuples, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
default:
tuples = append(tuples, policies.Policy{Object: resp.ResourceObjectId})
}
}
}
func (ps *policyService) retrieveSubjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) ([]policies.Policy, string, error) {
subjectsReq := v1.LookupSubjectsRequest{
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{
FullyConsistent: true,
},
},
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
Permission: pr.Permission,
SubjectObjectType: pr.SubjectType,
OptionalSubjectRelation: pr.SubjectRelation,
OptionalConcreteLimit: uint32(limit),
WildcardOption: v1.LookupSubjectsRequest_WILDCARD_OPTION_INCLUDE_WILDCARDS,
}
if nextPageToken != "" {
subjectsReq.OptionalCursor = &v1.Cursor{Token: nextPageToken}
}
stream, err := ps.permissionClient.LookupSubjects(ctx, &subjectsReq)
if err != nil {
return nil, "", errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
}
subjects := []*v1.LookupSubjectsResponse{}
var token string
for {
resp, err := stream.Recv()
switch err {
case nil:
subjects = append(subjects, resp)
case io.EOF:
if len(subjects) > 0 && subjects[len(subjects)-1].AfterResultCursor != nil {
token = subjects[len(subjects)-1].AfterResultCursor.Token
}
return subjectsToAuthPolicies(subjects), token, nil
default:
if len(subjects) > 0 && subjects[len(subjects)-1].AfterResultCursor != nil {
token = subjects[len(subjects)-1].AfterResultCursor.Token
}
return []policies.Policy{}, token, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
}
}
}
func (ps *policyService) retrieveAllSubjects(ctx context.Context, pr policies.Policy) ([]policies.Policy, error) {
var tuples []policies.Policy
nextPageToken := ""
for i := 0; ; i++ {
relationTuples, npt, err := ps.retrieveSubjects(ctx, pr, nextPageToken, defRetrieveAllLimit)
if err != nil {
return tuples, err
}
tuples = append(tuples, relationTuples...)
if npt == "" || (len(tuples) < defRetrieveAllLimit) {
break
}
nextPageToken = npt
}
return tuples, nil
}
func (ps *policyService) retrievePermissions(ctx context.Context, pr policies.Policy, filterPermission []string) (policies.Permissions, error) {
var permissionChecks []*v1.CheckBulkPermissionsRequestItem
for _, fp := range filterPermission {
permissionChecks = append(permissionChecks, &v1.CheckBulkPermissionsRequestItem{
Resource: &v1.ObjectReference{
ObjectType: pr.ObjectType,
ObjectId: pr.Object,
},
Permission: fp,
Subject: &v1.SubjectReference{
Object: &v1.ObjectReference{
ObjectType: pr.SubjectType,
ObjectId: pr.Subject,
},
OptionalRelation: pr.SubjectRelation,
},
})
}
resp, err := ps.client.PermissionsServiceClient.CheckBulkPermissions(ctx, &v1.CheckBulkPermissionsRequest{
Consistency: &v1.Consistency{
Requirement: &v1.Consistency_FullyConsistent{
FullyConsistent: true,
},
},
Items: permissionChecks,
})
if err != nil {
return policies.Permissions{}, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
}
permissions := []string{}
for _, pair := range resp.Pairs {
if pair.GetError() != nil {
s := pair.GetError()
return policies.Permissions{}, errors.Wrap(errRetrievePolicies, convertGRPCStatusToError(convertToGrpcStatus(s)))
}
item := pair.GetItem()
req := pair.GetRequest()
if item != nil && req != nil && item.Permissionship == v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION {
permissions = append(permissions, req.GetPermission())
}
}
return permissions, nil
}
func groupPreConditions(pr policies.Policy) ([]*v1.Precondition, error) {
// - PARENT_GROUP (subject) with DOMAIN RELATION to DOMAIN
precond := []*v1.Precondition{
{
Operation: v1.Precondition_OPERATION_MUST_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.GroupType,
OptionalResourceId: pr.Subject,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
OptionalSubjectId: pr.Domain,
},
},
},
}
if pr.ObjectKind != policies.ChannelsKind {
precond = append(precond,
&v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.GroupType,
OptionalResourceId: pr.Object,
OptionalRelation: policies.ParentGroupRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.GroupType,
},
},
},
)
}
switch {
// - NEW CHILD_GROUP (object) with out DOMAIN RELATION to ANY DOMAIN
case pr.ObjectType == policies.GroupType && pr.ObjectKind == policies.NewGroupKind:
precond = append(precond,
&v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.GroupType,
OptionalResourceId: pr.Object,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
},
},
},
)
default:
// - CHILD_GROUP (object) with DOMAIN RELATION to DOMAIN
precond = append(precond,
&v1.Precondition{
Operation: v1.Precondition_OPERATION_MUST_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.GroupType,
OptionalResourceId: pr.Object,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
OptionalSubjectId: pr.Domain,
},
},
},
)
}
return precond, nil
}
func channelThingPreCondition(pr policies.Policy) ([]*v1.Precondition, error) {
if pr.SubjectKind != policies.ChannelsKind {
return nil, errors.Wrap(errors.ErrMalformedEntity, errInvalidSubject)
}
precond := []*v1.Precondition{
{
Operation: v1.Precondition_OPERATION_MUST_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.GroupType,
OptionalResourceId: pr.Subject,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
OptionalSubjectId: pr.Domain,
},
},
},
{
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.GroupType,
OptionalRelation: policies.ParentGroupRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.GroupType,
OptionalSubjectId: pr.Subject,
},
},
},
{
Operation: v1.Precondition_OPERATION_MUST_MATCH,
Filter: &v1.RelationshipFilter{
ResourceType: policies.ThingType,
OptionalResourceId: pr.Object,
OptionalRelation: policies.DomainRelation,
OptionalSubjectFilter: &v1.SubjectFilter{
SubjectType: policies.DomainType,
OptionalSubjectId: pr.Domain,
},
},
},
}
return precond, nil
}
func objectsToAuthPolicies(objects []*v1.LookupResourcesResponse) []policies.Policy {
var policyList []policies.Policy
for _, obj := range objects {
policyList = append(policyList, policies.Policy{
Object: obj.GetResourceObjectId(),
})
}
return policyList
}
func subjectsToAuthPolicies(subjects []*v1.LookupSubjectsResponse) []policies.Policy {
var policyList []policies.Policy
for _, sub := range subjects {
policyList = append(policyList, policies.Policy{
Subject: sub.Subject.GetSubjectObjectId(),
})
}
return policyList
}
func handleSpicedbError(err error) error {
if st, ok := status.FromError(err); ok {
return convertGRPCStatusToError(st)
}
return err
}
func convertToGrpcStatus(gst *gstatus.Status) *status.Status {
st := status.New(codes.Code(gst.Code), gst.GetMessage())
return st
}
func convertGRPCStatusToError(st *status.Status) error {
switch st.Code() {
case codes.NotFound:
return errors.Wrap(repoerr.ErrNotFound, errors.New(st.Message()))
case codes.InvalidArgument:
return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message()))
case codes.AlreadyExists:
return errors.Wrap(repoerr.ErrConflict, errors.New(st.Message()))
case codes.Unauthenticated:
return errors.Wrap(svcerr.ErrAuthentication, errors.New(st.Message()))
case codes.Internal:
return errors.Wrap(errInternal, errors.New(st.Message()))
case codes.OK:
if msg := st.Message(); msg != "" {
return errors.Wrap(errors.ErrUnidentified, errors.New(msg))
}
return nil
case codes.FailedPrecondition:
return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message()))
case codes.PermissionDenied:
return errors.Wrap(svcerr.ErrAuthorization, errors.New(st.Message()))
default:
return errors.Wrap(fmt.Errorf("unexpected gRPC status: %s (status code:%v)", st.Code().String(), st.Code()), errors.New(st.Message()))
}
}

Some files were not shown because too many files have changed in this diff Show More