mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 07:20:19 +00:00
NOISSUE - Remove SuperMQ duplicates (#23)
* Update docker-compose to use SuperMQ Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Remove duplicate services Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Update Bootstrap Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Update other services to use SMQ Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Switch config prefix to SMQ Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Remove leftovers Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Remove duplicate interface definitions Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Remove unused actions Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Remove unused API docs Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Resolve linter comments Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> * Fix provision Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com> --------- Signed-off-by: Dusan Borovcanin <borovcanindusan1@gmail.com>
This commit is contained in:
@@ -1,209 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package apiutil
|
||||
|
||||
import "github.com/absmach/magistrala/pkg/errors"
|
||||
|
||||
// Errors defined in this file are used by the LoggingErrorEncoder decorator
|
||||
// to distinguish and log API request validation errors and avoid that service
|
||||
// errors are logged twice.
|
||||
var (
|
||||
// ErrValidation indicates that an error was returned by the API.
|
||||
ErrValidation = errors.New("something went wrong with the request")
|
||||
|
||||
// ErrBearerToken indicates missing or invalid bearer user token.
|
||||
ErrBearerToken = errors.New("missing or invalid bearer user token")
|
||||
|
||||
// ErrBearerKey indicates missing or invalid bearer entity key.
|
||||
ErrBearerKey = errors.New("missing or invalid bearer entity key")
|
||||
|
||||
// ErrMissingID indicates missing entity ID.
|
||||
ErrMissingID = errors.New("missing entity id")
|
||||
|
||||
// ErrInvalidAuthKey indicates invalid auth key.
|
||||
ErrInvalidAuthKey = errors.New("invalid auth key")
|
||||
|
||||
// ErrInvalidIDFormat indicates an invalid ID format.
|
||||
ErrInvalidIDFormat = errors.New("invalid id format provided")
|
||||
|
||||
// ErrNameSize indicates that name size exceeds the max.
|
||||
ErrNameSize = errors.New("invalid name size")
|
||||
|
||||
// ErrEmailSize indicates that email size exceeds the max.
|
||||
ErrEmailSize = errors.New("invalid email size")
|
||||
|
||||
// ErrInvalidRole indicates that an invalid role.
|
||||
ErrInvalidRole = errors.New("invalid client role")
|
||||
|
||||
// ErrLimitSize indicates that an invalid limit.
|
||||
ErrLimitSize = errors.New("invalid limit size")
|
||||
|
||||
// ErrOffsetSize indicates an invalid offset.
|
||||
ErrOffsetSize = errors.New("invalid offset size")
|
||||
|
||||
// ErrInvalidOrder indicates an invalid list order.
|
||||
ErrInvalidOrder = errors.New("invalid list order provided")
|
||||
|
||||
// ErrInvalidDirection indicates an invalid list direction.
|
||||
ErrInvalidDirection = errors.New("invalid list direction provided")
|
||||
|
||||
// ErrInvalidMemberKind indicates an invalid member kind.
|
||||
ErrInvalidMemberKind = errors.New("invalid member kind")
|
||||
|
||||
// ErrEmptyList indicates that entity data is empty.
|
||||
ErrEmptyList = errors.New("empty list provided")
|
||||
|
||||
// ErrMalformedPolicy indicates that policies are malformed.
|
||||
ErrMalformedPolicy = errors.New("malformed policy")
|
||||
|
||||
// ErrMissingPolicySub indicates that policies are subject.
|
||||
ErrMissingPolicySub = errors.New("malformed policy subject")
|
||||
|
||||
// ErrMissingPolicyObj indicates missing policies object.
|
||||
ErrMissingPolicyObj = errors.New("malformed policy object")
|
||||
|
||||
// ErrMalformedPolicyAct indicates missing policies action.
|
||||
ErrMalformedPolicyAct = errors.New("malformed policy action")
|
||||
|
||||
// ErrMissingPolicyEntityType indicates missing policies entity type.
|
||||
ErrMissingPolicyEntityType = errors.New("missing policy entity type")
|
||||
|
||||
// ErrMalformedPolicyPer indicates missing policies relation.
|
||||
ErrMalformedPolicyPer = errors.New("malformed policy permission")
|
||||
|
||||
// ErrMissingCertData indicates missing cert data (ttl).
|
||||
ErrMissingCertData = errors.New("missing certificate data")
|
||||
|
||||
// ErrInvalidCertData indicates invalid cert data (ttl).
|
||||
ErrInvalidCertData = errors.New("invalid certificate data")
|
||||
|
||||
// ErrInvalidTopic indicates an invalid subscription topic.
|
||||
ErrInvalidTopic = errors.New("invalid Subscription topic")
|
||||
|
||||
// ErrInvalidContact indicates an invalid subscription contract.
|
||||
ErrInvalidContact = errors.New("invalid Subscription contact")
|
||||
|
||||
// ErrMissingEmail indicates missing email.
|
||||
ErrMissingEmail = errors.New("missing email")
|
||||
|
||||
// ErrInvalidEmail indicates missing email.
|
||||
ErrInvalidEmail = errors.New("invalid email")
|
||||
|
||||
// ErrMissingHost indicates missing host.
|
||||
ErrMissingHost = errors.New("missing host")
|
||||
|
||||
// ErrMissingPass indicates missing password.
|
||||
ErrMissingPass = errors.New("missing password")
|
||||
|
||||
// ErrMissingConfPass indicates missing conf password.
|
||||
ErrMissingConfPass = errors.New("missing conf password")
|
||||
|
||||
// ErrInvalidResetPass indicates an invalid reset password.
|
||||
ErrInvalidResetPass = errors.New("invalid reset password")
|
||||
|
||||
// ErrInvalidComparator indicates an invalid comparator.
|
||||
ErrInvalidComparator = errors.New("invalid comparator")
|
||||
|
||||
// ErrMissingMemberType indicates missing group member type.
|
||||
ErrMissingMemberType = errors.New("missing group member type")
|
||||
|
||||
// ErrMissingMemberKind indicates missing group member kind.
|
||||
ErrMissingMemberKind = errors.New("missing group member kind")
|
||||
|
||||
// ErrMissingRelation indicates missing relation.
|
||||
ErrMissingRelation = errors.New("missing relation")
|
||||
|
||||
// ErrInvalidRelation indicates an invalid relation.
|
||||
ErrInvalidRelation = errors.New("invalid relation")
|
||||
|
||||
// ErrInvalidAPIKey indicates an invalid API key type.
|
||||
ErrInvalidAPIKey = errors.New("invalid api key type")
|
||||
|
||||
// ErrBootstrapState indicates an invalid bootstrap state.
|
||||
ErrBootstrapState = errors.New("invalid bootstrap state")
|
||||
|
||||
// ErrInvitationState indicates an invalid invitation state.
|
||||
ErrInvitationState = errors.New("invalid invitation state")
|
||||
|
||||
// ErrMissingIdentity indicates missing entity Identity.
|
||||
ErrMissingIdentity = errors.New("missing entity identity")
|
||||
|
||||
// ErrMissingSecret indicates missing secret.
|
||||
ErrMissingSecret = errors.New("missing secret")
|
||||
|
||||
// ErrPasswordFormat indicates weak password.
|
||||
ErrPasswordFormat = errors.New("password does not meet the requirements")
|
||||
|
||||
// ErrMissingName indicates missing identity name.
|
||||
ErrMissingName = errors.New("missing identity name")
|
||||
|
||||
// ErrMissingName indicates missing alias.
|
||||
ErrMissingAlias = errors.New("missing alias")
|
||||
|
||||
// ErrInvalidLevel indicates an invalid group level.
|
||||
ErrInvalidLevel = errors.New("invalid group level (should be between 0 and 5)")
|
||||
|
||||
// ErrNotFoundParam indicates that the parameter was not found in the query.
|
||||
ErrNotFoundParam = errors.New("parameter not found in the query")
|
||||
|
||||
// ErrInvalidQueryParams indicates invalid query parameters.
|
||||
ErrInvalidQueryParams = errors.New("invalid query parameters")
|
||||
|
||||
// ErrInvalidVisibilityType indicates invalid visibility type.
|
||||
ErrInvalidVisibilityType = errors.New("invalid visibility type")
|
||||
|
||||
// ErrUnsupportedContentType indicates unacceptable or lack of Content-Type.
|
||||
ErrUnsupportedContentType = errors.New("unsupported content type")
|
||||
|
||||
// ErrRollbackTx indicates failed to rollback transaction.
|
||||
ErrRollbackTx = errors.New("failed to rollback transaction")
|
||||
|
||||
// ErrInvalidAggregation indicates invalid aggregation value.
|
||||
ErrInvalidAggregation = errors.New("invalid aggregation value")
|
||||
|
||||
// ErrInvalidInterval indicates invalid interval value.
|
||||
ErrInvalidInterval = errors.New("invalid interval value")
|
||||
|
||||
// ErrMissingFrom indicates missing from value.
|
||||
ErrMissingFrom = errors.New("missing from time value")
|
||||
|
||||
// ErrMissingTo indicates missing to value.
|
||||
ErrMissingTo = errors.New("missing to time value")
|
||||
|
||||
// ErrEmptyMessage indicates empty message.
|
||||
ErrEmptyMessage = errors.New("empty message")
|
||||
|
||||
// ErrMissingEntityType indicates missing entity type.
|
||||
ErrMissingEntityType = errors.New("missing entity type")
|
||||
|
||||
// ErrInvalidEntityType indicates invalid entity type.
|
||||
ErrInvalidEntityType = errors.New("invalid entity type")
|
||||
|
||||
// ErrInvalidTimeFormat indicates invalid time format i.e not unix time.
|
||||
ErrInvalidTimeFormat = errors.New("invalid time format use unix time")
|
||||
|
||||
// ErrEmptySearchQuery indicates search query should not be empty.
|
||||
ErrEmptySearchQuery = errors.New("search query must not be empty")
|
||||
|
||||
// ErrLenSearchQuery indicates search query length.
|
||||
ErrLenSearchQuery = errors.New("search query must be at least 3 characters")
|
||||
|
||||
// ErrMissingDomainID indicates missing domainID.
|
||||
ErrMissingDomainID = errors.New("missing domainID")
|
||||
|
||||
// ErrMissingUsername indicates missing user name.
|
||||
ErrMissingUsername = errors.New("missing username")
|
||||
|
||||
// ErrInvalidUsername indicates missing user name.
|
||||
ErrInvalidUsername = errors.New("invalid username")
|
||||
|
||||
// ErrMissingFirstName indicates missing first name.
|
||||
ErrMissingFirstName = errors.New("missing first name")
|
||||
|
||||
// ErrMissingLastName indicates missing last name.
|
||||
ErrMissingLastName = errors.New("missing last name")
|
||||
|
||||
// ErrInvalidProfilePictureURL indicates that the profile picture url is invalid.
|
||||
ErrInvalidProfilePictureURL = errors.New("invalid profile picture url")
|
||||
)
|
||||
@@ -1,10 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package apiutil
|
||||
|
||||
// ErrorRes represents the HTTP error response body.
|
||||
type ErrorRes struct {
|
||||
Err string `json:"error"`
|
||||
Msg string `json:"message"`
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package apiutil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BearerPrefix represents the token prefix for Bearer authentication scheme.
|
||||
const BearerPrefix = "Bearer "
|
||||
|
||||
// ThingPrefix represents the key prefix for Thing authentication scheme.
|
||||
const ThingPrefix = "Thing "
|
||||
|
||||
// ExtractBearerToken returns value of the bearer token. If there is no bearer token - an empty value is returned.
|
||||
func ExtractBearerToken(r *http.Request) string {
|
||||
token := r.Header.Get("Authorization")
|
||||
|
||||
if !strings.HasPrefix(token, BearerPrefix) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimPrefix(token, BearerPrefix)
|
||||
}
|
||||
|
||||
// ExtractThingKey returns value of the thing key. If there is no thing key - an empty value is returned.
|
||||
func ExtractThingKey(r *http.Request) string {
|
||||
token := r.Header.Get("Authorization")
|
||||
|
||||
if !strings.HasPrefix(token, ThingPrefix) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimPrefix(token, ThingPrefix)
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package apiutil_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/apiutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestExtractBearerToken(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
request *http.Request
|
||||
token string
|
||||
}{
|
||||
{
|
||||
desc: "valid bearer token",
|
||||
request: &http.Request{
|
||||
Header: map[string][]string{
|
||||
"Authorization": {"Bearer 123"},
|
||||
},
|
||||
},
|
||||
token: "123",
|
||||
},
|
||||
{
|
||||
desc: "invalid bearer token",
|
||||
request: &http.Request{
|
||||
Header: map[string][]string{
|
||||
"Authorization": {"123"},
|
||||
},
|
||||
},
|
||||
token: "",
|
||||
},
|
||||
{
|
||||
desc: "empty bearer token",
|
||||
request: &http.Request{
|
||||
Header: map[string][]string{
|
||||
"Authorization": {""},
|
||||
},
|
||||
},
|
||||
token: "",
|
||||
},
|
||||
{
|
||||
desc: "empty header",
|
||||
request: &http.Request{
|
||||
Header: map[string][]string{},
|
||||
},
|
||||
token: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
token := apiutil.ExtractBearerToken(c.request)
|
||||
assert.Equal(t, c.token, token)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractThingKey(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
request *http.Request
|
||||
token string
|
||||
}{
|
||||
{
|
||||
desc: "valid bearer token",
|
||||
request: &http.Request{
|
||||
Header: map[string][]string{
|
||||
"Authorization": {"Thing 123"},
|
||||
},
|
||||
},
|
||||
token: "123",
|
||||
},
|
||||
{
|
||||
desc: "invalid bearer token",
|
||||
request: &http.Request{
|
||||
Header: map[string][]string{
|
||||
"Authorization": {"123"},
|
||||
},
|
||||
},
|
||||
token: "",
|
||||
},
|
||||
{
|
||||
desc: "empty bearer token",
|
||||
request: &http.Request{
|
||||
Header: map[string][]string{
|
||||
"Authorization": {""},
|
||||
},
|
||||
},
|
||||
token: "",
|
||||
},
|
||||
{
|
||||
desc: "empty header",
|
||||
request: &http.Request{
|
||||
Header: map[string][]string{},
|
||||
},
|
||||
token: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
token := apiutil.ExtractThingKey(c.request)
|
||||
assert.Equal(t, c.token, token)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package apiutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
)
|
||||
|
||||
// LoggingErrorEncoder is a go-kit error encoder logging decorator.
|
||||
func LoggingErrorEncoder(logger *slog.Logger, enc kithttp.ErrorEncoder) kithttp.ErrorEncoder {
|
||||
return func(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
if errors.Contains(err, ErrValidation) {
|
||||
logger.Error(err.Error())
|
||||
}
|
||||
enc(ctx, err, w)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadStringQuery reads the value of string http query parameters for a given key.
|
||||
func ReadStringQuery(r *http.Request, key, def string) (string, error) {
|
||||
vals := r.URL.Query()[key]
|
||||
if len(vals) > 1 {
|
||||
return "", ErrInvalidQueryParams
|
||||
}
|
||||
|
||||
if len(vals) == 0 {
|
||||
return def, nil
|
||||
}
|
||||
|
||||
return vals[0], nil
|
||||
}
|
||||
|
||||
// ReadMetadataQuery reads the value of json http query parameters for a given key.
|
||||
func ReadMetadataQuery(r *http.Request, key string, def map[string]interface{}) (map[string]interface{}, error) {
|
||||
vals := r.URL.Query()[key]
|
||||
if len(vals) > 1 {
|
||||
return nil, ErrInvalidQueryParams
|
||||
}
|
||||
|
||||
if len(vals) == 0 {
|
||||
return def, nil
|
||||
}
|
||||
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal([]byte(vals[0]), &m)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ErrInvalidQueryParams, err)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ReadBoolQuery reads boolean query parameters in a given http request.
|
||||
func ReadBoolQuery(r *http.Request, key string, def bool) (bool, error) {
|
||||
vals := r.URL.Query()[key]
|
||||
if len(vals) > 1 {
|
||||
return false, ErrInvalidQueryParams
|
||||
}
|
||||
|
||||
if len(vals) == 0 {
|
||||
return def, nil
|
||||
}
|
||||
|
||||
b, err := strconv.ParseBool(vals[0])
|
||||
if err != nil {
|
||||
return false, errors.Wrap(ErrInvalidQueryParams, err)
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
type number interface {
|
||||
int64 | float64 | uint16 | uint64
|
||||
}
|
||||
|
||||
// ReadNumQuery returns a numeric value.
|
||||
func ReadNumQuery[N number](r *http.Request, key string, def N) (N, error) {
|
||||
vals := r.URL.Query()[key]
|
||||
if len(vals) > 1 {
|
||||
return 0, ErrInvalidQueryParams
|
||||
}
|
||||
if len(vals) == 0 {
|
||||
return def, nil
|
||||
}
|
||||
val := vals[0]
|
||||
|
||||
switch any(def).(type) {
|
||||
case int64:
|
||||
v, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(ErrInvalidQueryParams, err)
|
||||
}
|
||||
return N(v), nil
|
||||
case uint64:
|
||||
v, err := strconv.ParseUint(val, 10, 64)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(ErrInvalidQueryParams, err)
|
||||
}
|
||||
return N(v), nil
|
||||
case uint16:
|
||||
v, err := strconv.ParseUint(val, 10, 16)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(ErrInvalidQueryParams, err)
|
||||
}
|
||||
return N(v), nil
|
||||
case float64:
|
||||
v, err := strconv.ParseFloat(val, 64)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(ErrInvalidQueryParams, err)
|
||||
}
|
||||
return N(v), nil
|
||||
default:
|
||||
return def, nil
|
||||
}
|
||||
}
|
||||
@@ -1,364 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package apiutil_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/apiutil"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestReadStringQuery(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
url string
|
||||
key string
|
||||
ret string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid string query",
|
||||
url: "http://localhost:8080/?key=test",
|
||||
key: "key",
|
||||
ret: "test",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty string query",
|
||||
url: "http://localhost:8080/",
|
||||
key: "key",
|
||||
ret: "",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "multiple string query",
|
||||
url: "http://localhost:8080/?key=test&key=random",
|
||||
key: "key",
|
||||
ret: "",
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
parsedURL, err := url.Parse(c.url)
|
||||
assert.NoError(t, err)
|
||||
|
||||
r := &http.Request{URL: parsedURL}
|
||||
ret, err := apiutil.ReadStringQuery(r, c.key, "")
|
||||
assert.Equal(t, c.err, err)
|
||||
assert.Equal(t, c.ret, ret)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMetadataQuery(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
url string
|
||||
key string
|
||||
ret map[string]interface{}
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid metadata query",
|
||||
url: "http://localhost:8080/?key={\"test\":\"test\"}",
|
||||
key: "key",
|
||||
ret: map[string]interface{}{"test": "test"},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty metadata query",
|
||||
url: "http://localhost:8080/",
|
||||
key: "key",
|
||||
ret: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "multiple metadata query",
|
||||
url: "http://localhost:8080/?key={\"test\":\"test\"}&key={\"random\":\"random\"}",
|
||||
key: "key",
|
||||
ret: nil,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "invalid metadata query",
|
||||
url: "http://localhost:8080/?key=abc",
|
||||
key: "key",
|
||||
ret: nil,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
parsedURL, err := url.Parse(c.url)
|
||||
assert.NoError(t, err)
|
||||
|
||||
r := &http.Request{URL: parsedURL}
|
||||
ret, err := apiutil.ReadMetadataQuery(r, c.key, nil)
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err))
|
||||
assert.Equal(t, c.ret, ret)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadBoolQuery(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
url string
|
||||
key string
|
||||
ret bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid bool query",
|
||||
url: "http://localhost:8080/?key=true",
|
||||
key: "key",
|
||||
ret: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid bool query",
|
||||
url: "http://localhost:8080/?key=false",
|
||||
key: "key",
|
||||
ret: false,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "invalid bool query",
|
||||
url: "http://localhost:8080/?key=abc",
|
||||
key: "key",
|
||||
ret: false,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "empty bool query",
|
||||
url: "http://localhost:8080/",
|
||||
key: "key",
|
||||
ret: false,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "multiple bool query",
|
||||
url: "http://localhost:8080/?key=true&key=false",
|
||||
key: "key",
|
||||
ret: false,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
parsedURL, err := url.Parse(c.url)
|
||||
assert.NoError(t, err)
|
||||
|
||||
r := &http.Request{URL: parsedURL}
|
||||
ret, err := apiutil.ReadBoolQuery(r, c.key, false)
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err))
|
||||
assert.Equal(t, c.ret, ret)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadNumQuery(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
url string
|
||||
key string
|
||||
numType string
|
||||
ret interface{}
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid int64 query",
|
||||
url: "http://localhost:8080/?key=123",
|
||||
key: "key",
|
||||
numType: "int64",
|
||||
ret: int64(123),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid float64 query",
|
||||
url: "http://localhost:8080/?key=1.23",
|
||||
key: "key",
|
||||
numType: "float64",
|
||||
ret: float64(1.23),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid uint64 query",
|
||||
url: "http://localhost:8080/?key=123",
|
||||
key: "key",
|
||||
numType: "uint64",
|
||||
ret: uint64(123),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "valid uint16 query",
|
||||
url: "http://localhost:8080/?key=123",
|
||||
key: "key",
|
||||
numType: "uint16",
|
||||
ret: uint16(123),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "invalid int64 query",
|
||||
url: "http://localhost:8080/?key=abc",
|
||||
key: "key",
|
||||
numType: "int64",
|
||||
ret: int64(0),
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "invalid float64 query",
|
||||
url: "http://localhost:8080/?key=abc",
|
||||
key: "key",
|
||||
numType: "float64",
|
||||
ret: float64(0),
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "invalid uint64 query",
|
||||
url: "http://localhost:8080/?key=abc",
|
||||
key: "key",
|
||||
numType: "uint64",
|
||||
ret: uint64(0),
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "invalid uint16 query",
|
||||
url: "http://localhost:8080/?key=abc",
|
||||
key: "key",
|
||||
numType: "uint16",
|
||||
ret: uint16(0),
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "empty int64 query",
|
||||
url: "http://localhost:8080/",
|
||||
key: "key",
|
||||
numType: "int64",
|
||||
ret: int64(0),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty float64 query",
|
||||
url: "http://localhost:8080/",
|
||||
key: "key",
|
||||
numType: "float64",
|
||||
ret: float64(0),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty uint16 query",
|
||||
url: "http://localhost:8080/",
|
||||
key: "key",
|
||||
numType: "uint16",
|
||||
ret: uint16(0),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty uint64 query",
|
||||
url: "http://localhost:8080/",
|
||||
key: "key",
|
||||
numType: "uint64",
|
||||
ret: uint64(0),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "multiple int64 query",
|
||||
url: "http://localhost:8080/?key=123&key=456",
|
||||
key: "key",
|
||||
numType: "int64",
|
||||
ret: int64(0),
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "multiple float64 query",
|
||||
url: "http://localhost:8080/?key=1.23&key=4.56",
|
||||
key: "key",
|
||||
numType: "float64",
|
||||
ret: float64(0),
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "multiple uint16 query",
|
||||
url: "http://localhost:8080/?key=123&key=456",
|
||||
key: "key",
|
||||
numType: "uint16",
|
||||
ret: uint16(0),
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "multiple uint64 query",
|
||||
url: "http://localhost:8080/?key=123&key=456",
|
||||
key: "key",
|
||||
numType: "uint64",
|
||||
ret: uint64(0),
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
parsedURL, err := url.Parse(c.url)
|
||||
assert.NoError(t, err)
|
||||
|
||||
r := &http.Request{URL: parsedURL}
|
||||
var ret interface{}
|
||||
switch c.numType {
|
||||
case "int64":
|
||||
ret, err = apiutil.ReadNumQuery[int64](r, c.key, 0)
|
||||
case "float64":
|
||||
ret, err = apiutil.ReadNumQuery[float64](r, c.key, 0)
|
||||
case "uint64":
|
||||
ret, err = apiutil.ReadNumQuery[uint64](r, c.key, 0)
|
||||
case "uint16":
|
||||
ret, err = apiutil.ReadNumQuery[uint16](r, c.key, 0)
|
||||
}
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err))
|
||||
assert.Equal(t, c.ret, ret)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggingErrorEncoder(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "error contains ErrValidation",
|
||||
err: errors.Wrap(apiutil.ErrValidation, svcerr.ErrAuthentication),
|
||||
},
|
||||
{
|
||||
desc: "error does not contain ErrValidation",
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
encCalled := false
|
||||
encFunc := func(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
encCalled = true
|
||||
}
|
||||
|
||||
errorEncoder := apiutil.LoggingErrorEncoder(mglog.NewMock(), encFunc)
|
||||
errorEncoder(context.Background(), c.err, httptest.NewRecorder())
|
||||
|
||||
assert.True(t, encCalled)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package authn
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
DomainUserID string
|
||||
UserID string
|
||||
DomainID string
|
||||
SuperAdmin bool
|
||||
}
|
||||
|
||||
// Authn is magistrala authentication library.
|
||||
//
|
||||
//go:generate mockery --name Authentication --output=./mocks --filename authn.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Authentication interface {
|
||||
Authenticate(ctx context.Context, token string) (Session, error)
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package authsvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
"github.com/absmach/magistrala/auth/api/grpc/auth"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/absmach/magistrala/pkg/grpcclient"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
type authentication struct {
|
||||
authSvcClient magistrala.AuthServiceClient
|
||||
}
|
||||
|
||||
var _ authn.Authentication = (*authentication)(nil)
|
||||
|
||||
func NewAuthentication(ctx context.Context, cfg grpcclient.Config) (authn.Authentication, grpcclient.Handler, error) {
|
||||
client, err := grpcclient.NewHandler(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "auth",
|
||||
})
|
||||
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
|
||||
return nil, nil, grpcclient.ErrSvcNotServing
|
||||
}
|
||||
authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout)
|
||||
return authentication{authSvcClient}, client, nil
|
||||
}
|
||||
|
||||
func (a authentication) Authenticate(ctx context.Context, token string) (authn.Session, error) {
|
||||
res, err := a.authSvcClient.Authenticate(ctx, &magistrala.AuthNReq{Token: token})
|
||||
if err != nil {
|
||||
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
|
||||
}
|
||||
return authn.Session{DomainUserID: res.GetId(), UserID: res.GetUserId(), DomainID: res.GetDomainId()}, nil
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package authn
|
||||
@@ -1,60 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
authn "github.com/absmach/magistrala/pkg/authn"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Authentication is an autogenerated mock type for the Authentication type
|
||||
type Authentication struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Authenticate provides a mock function with given fields: ctx, token
|
||||
func (_m *Authentication) Authenticate(ctx context.Context, token string) (authn.Session, error) {
|
||||
ret := _m.Called(ctx, token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Authenticate")
|
||||
}
|
||||
|
||||
var r0 authn.Session
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (authn.Session, error)); ok {
|
||||
return rf(ctx, token)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) authn.Session); ok {
|
||||
r0 = rf(ctx, token)
|
||||
} else {
|
||||
r0 = ret.Get(0).(authn.Session)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, token)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewAuthentication creates a new instance of Authentication. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAuthentication(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Authentication {
|
||||
mock := &Authentication{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package authsvc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
"github.com/absmach/magistrala/auth/api/grpc/auth"
|
||||
"github.com/absmach/magistrala/pkg/authz"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/absmach/magistrala/pkg/grpcclient"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
type authorization struct {
|
||||
authSvcClient magistrala.AuthServiceClient
|
||||
}
|
||||
|
||||
var _ authz.Authorization = (*authorization)(nil)
|
||||
|
||||
func NewAuthorization(ctx context.Context, cfg grpcclient.Config) (authz.Authorization, grpcclient.Handler, error) {
|
||||
client, err := grpcclient.NewHandler(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "auth",
|
||||
})
|
||||
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
|
||||
return nil, nil, grpcclient.ErrSvcNotServing
|
||||
}
|
||||
authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout)
|
||||
return authorization{authSvcClient}, client, nil
|
||||
}
|
||||
|
||||
func (a authorization) Authorize(ctx context.Context, pr authz.PolicyReq) error {
|
||||
req := magistrala.AuthZReq{
|
||||
Domain: pr.Domain,
|
||||
SubjectType: pr.SubjectType,
|
||||
SubjectKind: pr.SubjectKind,
|
||||
SubjectRelation: pr.SubjectRelation,
|
||||
Subject: pr.Subject,
|
||||
Relation: pr.Relation,
|
||||
Permission: pr.Permission,
|
||||
Object: pr.Object,
|
||||
ObjectType: pr.ObjectType,
|
||||
}
|
||||
res, err := a.authSvcClient.Authorize(ctx, &req)
|
||||
if err != nil {
|
||||
return errors.Wrap(errors.ErrAuthorization, err)
|
||||
}
|
||||
if !res.Authorized {
|
||||
return errors.ErrAuthorization
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package authz
|
||||
|
||||
import "context"
|
||||
|
||||
type PolicyReq struct {
|
||||
// Domain contains the domain ID.
|
||||
Domain string `json:"domain,omitempty"`
|
||||
|
||||
// Subject contains the subject ID or Token.
|
||||
Subject string `json:"subject"`
|
||||
|
||||
// SubjectType contains the subject type. Supported subject types are
|
||||
// platform, group, domain, thing, users.
|
||||
SubjectType string `json:"subject_type"`
|
||||
|
||||
// SubjectKind contains the subject kind. Supported subject kinds are
|
||||
// token, users, platform, things, channels, groups, domain.
|
||||
SubjectKind string `json:"subject_kind"`
|
||||
|
||||
// SubjectRelation contains subject relations.
|
||||
SubjectRelation string `json:"subject_relation,omitempty"`
|
||||
|
||||
// Object contains the object ID.
|
||||
Object string `json:"object"`
|
||||
|
||||
// ObjectKind contains the object kind. Supported object kinds are
|
||||
// users, platform, things, channels, groups, domain.
|
||||
ObjectKind string `json:"object_kind"`
|
||||
|
||||
// ObjectType contains the object type. Supported object types are
|
||||
// platform, group, domain, thing, users.
|
||||
ObjectType string `json:"object_type"`
|
||||
|
||||
// Relation contains the relation. Supported relations are administrator, editor, contributor, member, guest, parent_group,group,domain.
|
||||
Relation string `json:"relation,omitempty"`
|
||||
|
||||
// Permission contains the permission. Supported permissions are admin, delete, edit, share, view,
|
||||
// membership, create, admin_only, edit_only, view_only, membership_only, ext_admin, ext_edit, ext_view.
|
||||
Permission string `json:"permission,omitempty"`
|
||||
}
|
||||
|
||||
// Authz is magistrala authorization library.
|
||||
//
|
||||
//go:generate mockery --name Authorization --output=./mocks --filename authz.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Authorization interface {
|
||||
Authorize(ctx context.Context, pr PolicyReq) error
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package authz
|
||||
@@ -1,50 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
authz "github.com/absmach/magistrala/pkg/authz"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Authorization is an autogenerated mock type for the Authorization type
|
||||
type Authorization struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Authorize provides a mock function with given fields: ctx, pr
|
||||
func (_m *Authorization) Authorize(ctx context.Context, pr authz.PolicyReq) error {
|
||||
ret := _m.Called(ctx, pr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Authorize")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authz.PolicyReq) error); ok {
|
||||
r0 = rf(ctx, pr)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewAuthorization creates a new instance of Authorization. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAuthorization(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Authorization {
|
||||
mock := &Authorization{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
UnpublishedEventsCheckInterval = 1 * time.Minute
|
||||
ConnCheckInterval = 100 * time.Millisecond
|
||||
MaxUnpublishedEvents uint64 = 1e4
|
||||
MaxEventStreamLen int64 = 1e6
|
||||
)
|
||||
|
||||
// Event represents an event.
|
||||
type Event interface {
|
||||
// Encode encodes event to map.
|
||||
Encode() (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
// Publisher specifies events publishing API.
|
||||
//
|
||||
//go:generate mockery --name Publisher --output=./mocks --filename publisher.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Publisher interface {
|
||||
// Publish publishes event to stream.
|
||||
Publish(ctx context.Context, event Event) error
|
||||
|
||||
// Close gracefully closes event publisher's connection.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// EventHandler represents event handler for Subscriber.
|
||||
type EventHandler interface {
|
||||
// Handle handles events passed by underlying implementation.
|
||||
Handle(ctx context.Context, event Event) error
|
||||
}
|
||||
|
||||
// SubscriberConfig represents event subscriber configuration.
|
||||
type SubscriberConfig struct {
|
||||
Consumer string
|
||||
Stream string
|
||||
Handler EventHandler
|
||||
}
|
||||
|
||||
// Subscriber specifies event subscription API.
|
||||
//
|
||||
//go:generate mockery --name Subscriber --output=./mocks --filename subscriber.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Subscriber interface {
|
||||
// Subscribe subscribes to the event stream and consumes events.
|
||||
Subscribe(ctx context.Context, cfg SubscriberConfig) error
|
||||
|
||||
// Close gracefully closes event subscriber's connection.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Read reads value from event map.
|
||||
// If value is not of type T, returns default value.
|
||||
func Read[T any](event map[string]interface{}, key string, def T) T {
|
||||
val, ok := event[key].(T)
|
||||
if !ok {
|
||||
return def
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
// ReadStringSlice reads string slice from event map.
|
||||
// If value is not a string slice, returns empty slice.
|
||||
func ReadStringSlice(event map[string]interface{}, key string) []string {
|
||||
var res []string
|
||||
|
||||
vals, ok := event[key].([]interface{})
|
||||
if !ok {
|
||||
return res
|
||||
}
|
||||
|
||||
for _, v := range vals {
|
||||
if s, ok := v.(string); ok {
|
||||
res = append(res, s)
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
events "github.com/absmach/magistrala/pkg/events"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Publisher is an autogenerated mock type for the Publisher type
|
||||
type Publisher struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *Publisher) Close() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Close")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Publish provides a mock function with given fields: ctx, event
|
||||
func (_m *Publisher) Publish(ctx context.Context, event events.Event) error {
|
||||
ret := _m.Called(ctx, event)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Publish")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, events.Event) error); ok {
|
||||
r0 = rf(ctx, event)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewPublisher creates a new instance of Publisher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewPublisher(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Publisher {
|
||||
mock := &Publisher{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
events "github.com/absmach/magistrala/pkg/events"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Subscriber is an autogenerated mock type for the Subscriber type
|
||||
type Subscriber struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *Subscriber) Close() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Close")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Subscribe provides a mock function with given fields: ctx, cfg
|
||||
func (_m *Subscriber) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error {
|
||||
ret := _m.Called(ctx, cfg)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Subscribe")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, events.SubscriberConfig) error); ok {
|
||||
r0 = rf(ctx, cfg)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewSubscriber creates a new instance of Subscriber. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewSubscriber(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Subscriber {
|
||||
mock := &Subscriber{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package redis contains the domain concept definitions needed to support
|
||||
// Magistrala redis events source service functionality.
|
||||
//
|
||||
// It provides the abstraction of the redis stream and its operations.
|
||||
package nats
|
||||
@@ -1,79 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
broker "github.com/absmach/magistrala/pkg/messaging/nats"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
)
|
||||
|
||||
// Max message payload size is 1MB.
|
||||
var reconnectBufSize = 1024 * 1024 * int(events.MaxUnpublishedEvents)
|
||||
|
||||
type pubEventStore struct {
|
||||
url string
|
||||
conn *nats.Conn
|
||||
publisher messaging.Publisher
|
||||
stream string
|
||||
}
|
||||
|
||||
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
|
||||
conn, err := nats.Connect(url, nats.MaxReconnects(maxReconnects), nats.ReconnectBufSize(reconnectBufSize))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
js, err := jetstream.New(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := js.CreateStream(ctx, jsStreamConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publisher, err := broker.NewPublisher(ctx, url, broker.Prefix(eventsPrefix), broker.JSStream(js))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
es := &pubEventStore{
|
||||
url: url,
|
||||
conn: conn,
|
||||
publisher: publisher,
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
return es, nil
|
||||
}
|
||||
|
||||
func (es *pubEventStore) Publish(ctx context.Context, event events.Event) error {
|
||||
values, err := event.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values["occurred_at"] = time.Now().UnixNano()
|
||||
|
||||
data, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record := &messaging.Message{
|
||||
Payload: data,
|
||||
}
|
||||
|
||||
return es.publisher.Publish(ctx, es.stream, record)
|
||||
}
|
||||
|
||||
func (es *pubEventStore) Close() error {
|
||||
es.conn.Close()
|
||||
|
||||
return es.publisher.Close()
|
||||
}
|
||||
@@ -1,325 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nats_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/events/nats"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
eventsChan = make(chan map[string]interface{})
|
||||
logger = mglog.NewMock()
|
||||
errFailed = errors.New("failed")
|
||||
numEvents = 100
|
||||
)
|
||||
|
||||
type testEvent struct {
|
||||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
func (te testEvent) Encode() (map[string]interface{}, error) {
|
||||
data := make(map[string]interface{})
|
||||
for k, v := range te.Data {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
data[k] = v
|
||||
case float64:
|
||||
data[k] = v
|
||||
default:
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data[k] = string(b)
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
_, err := nats.NewPublisher(context.Background(), "http://invaliurl.com", stream)
|
||||
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
|
||||
|
||||
publisher, err := nats.NewPublisher(context.Background(), natsURL, stream)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
defer publisher.Close()
|
||||
|
||||
_, err = nats.NewSubscriber(context.Background(), "http://invaliurl.com", logger)
|
||||
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
|
||||
|
||||
subcriber, err := nats.NewSubscriber(context.Background(), natsURL, logger)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
defer subcriber.Close()
|
||||
|
||||
cfg := events.SubscriberConfig{
|
||||
Stream: "events." + stream,
|
||||
Consumer: consumer,
|
||||
Handler: handler{},
|
||||
}
|
||||
err = subcriber.Subscribe(context.Background(), cfg)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
event map[string]interface{}
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish event successfully",
|
||||
err: nil,
|
||||
event: map[string]interface{}{
|
||||
"temperature": fmt.Sprintf("%f", rand.Float64()),
|
||||
"humidity": fmt.Sprintf("%f", rand.Float64()),
|
||||
"sensor_id": "abc123",
|
||||
"location": "Earth",
|
||||
"status": "normal",
|
||||
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
"operation": "create",
|
||||
"occurred_at": time.Now().UnixNano(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "publish with nil event",
|
||||
err: nil,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish event with invalid event location",
|
||||
err: fmt.Errorf("json: unsupported type: chan int"),
|
||||
event: map[string]interface{}{
|
||||
"temperature": fmt.Sprintf("%f", rand.Float64()),
|
||||
"humidity": fmt.Sprintf("%f", rand.Float64()),
|
||||
"sensor_id": "abc123",
|
||||
"location": make(chan int),
|
||||
"status": "normal",
|
||||
"timestamp": "invalid",
|
||||
"operation": "create",
|
||||
"occurred_at": time.Now().UnixNano(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "publish event with nested sting value",
|
||||
err: nil,
|
||||
event: map[string]interface{}{
|
||||
"temperature": fmt.Sprintf("%f", rand.Float64()),
|
||||
"humidity": fmt.Sprintf("%f", rand.Float64()),
|
||||
"sensor_id": "abc123",
|
||||
"location": map[string]string{
|
||||
"lat": fmt.Sprintf("%f", rand.Float64()),
|
||||
"lng": fmt.Sprintf("%f", rand.Float64()),
|
||||
},
|
||||
"status": "normal",
|
||||
"timestamp": "invalid",
|
||||
"operation": "create",
|
||||
"occurred_at": time.Now().UnixNano(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
event := testEvent{Data: tc.event}
|
||||
|
||||
err := publisher.Publish(context.Background(), event)
|
||||
switch tc.err {
|
||||
case nil:
|
||||
receivedEvent := <-eventsChan
|
||||
|
||||
val := int64(receivedEvent["occurred_at"].(float64))
|
||||
if assert.WithinRange(t, time.Unix(0, val), time.Now().Add(-time.Second), time.Now().Add(time.Second)) {
|
||||
delete(receivedEvent, "occurred_at")
|
||||
delete(tc.event, "occurred_at")
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.event["temperature"], receivedEvent["temperature"])
|
||||
assert.Equal(t, tc.event["humidity"], receivedEvent["humidity"])
|
||||
assert.Equal(t, tc.event["sensor_id"], receivedEvent["sensor_id"])
|
||||
assert.Equal(t, tc.event["status"], receivedEvent["status"])
|
||||
assert.Equal(t, tc.event["timestamp"], receivedEvent["timestamp"])
|
||||
assert.Equal(t, tc.event["operation"], receivedEvent["operation"])
|
||||
default:
|
||||
assert.ErrorContains(t, err, tc.err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubsub(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
stream string
|
||||
consumer string
|
||||
err error
|
||||
handler events.EventHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a stream",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same stream",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty stream with an empty consumer",
|
||||
stream: "",
|
||||
consumer: "",
|
||||
err: nats.ErrEmptyStream,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty stream with a valid consumer",
|
||||
stream: "",
|
||||
consumer: consumer,
|
||||
err: nats.ErrEmptyStream,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a valid stream with an empty consumer",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: "",
|
||||
err: nats.ErrEmptyConsumer,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to another stream",
|
||||
stream: fmt.Sprintf("events.%s.%d", stream, 1),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a stream with malformed handler",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
subcriber, err := nats.NewSubscriber(context.Background(), natsURL, logger)
|
||||
if err != nil {
|
||||
assert.Equal(t, err, tc.err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
cfg := events.SubscriberConfig{
|
||||
Stream: tc.stream,
|
||||
Consumer: tc.consumer,
|
||||
Handler: tc.handler,
|
||||
}
|
||||
switch err := subcriber.Subscribe(context.Background(), cfg); {
|
||||
case err == nil:
|
||||
assert.Nil(t, err)
|
||||
default:
|
||||
assert.Equal(t, err, tc.err)
|
||||
}
|
||||
|
||||
err = subcriber.Close()
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnavailablePublish(t *testing.T) {
|
||||
publisher, err := nats.NewPublisher(context.Background(), natsURL, stream)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
|
||||
subcriber, err := nats.NewSubscriber(context.Background(), natsURL, logger)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
|
||||
cfg := events.SubscriberConfig{
|
||||
Stream: "events." + stream,
|
||||
Consumer: consumer,
|
||||
Handler: handler{},
|
||||
}
|
||||
err = subcriber.Subscribe(context.Background(), cfg)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
|
||||
|
||||
err = pool.Client.PauseContainer(container.Container.ID)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on pausing container: %s", err))
|
||||
|
||||
spawnGoroutines(publisher, t)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
err = pool.Client.UnpauseContainer(container.Container.ID)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on unpausing container: %s", err))
|
||||
|
||||
// Wait for the events to be published.
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
err = publisher.Close()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on closing publisher: %s", err))
|
||||
|
||||
// read all the events from the channel and assert that they are 10.
|
||||
var receivedEvents []map[string]interface{}
|
||||
for i := 0; i < numEvents; i++ {
|
||||
event := <-eventsChan
|
||||
receivedEvents = append(receivedEvents, event)
|
||||
}
|
||||
assert.Len(t, receivedEvents, numEvents, "got unexpected number of events")
|
||||
}
|
||||
|
||||
func generateRandomEvent() testEvent {
|
||||
return testEvent{
|
||||
Data: map[string]interface{}{
|
||||
"temperature": fmt.Sprintf("%f", rand.Float64()),
|
||||
"humidity": fmt.Sprintf("%f", rand.Float64()),
|
||||
"sensor_id": fmt.Sprintf("%d", rand.Intn(1000)),
|
||||
"location": fmt.Sprintf("%f", rand.Float64()),
|
||||
"status": fmt.Sprintf("%d", rand.Intn(1000)),
|
||||
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
"operation": "create",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func spawnGoroutines(publisher events.Publisher, t *testing.T) {
|
||||
for i := 0; i < numEvents; i++ {
|
||||
go func() {
|
||||
err := publisher.Publish(context.Background(), generateRandomEvent())
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
fail bool
|
||||
}
|
||||
|
||||
func (h handler) Handle(_ context.Context, event events.Event) error {
|
||||
if h.fail {
|
||||
return errFailed
|
||||
}
|
||||
data, err := event.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventsChan <- data
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nats_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events/nats"
|
||||
"github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
natsURL string
|
||||
stream = "tests.events"
|
||||
consumer = "tests-consumer"
|
||||
pool *dockertest.Pool
|
||||
container *dockertest.Resource
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
var err error
|
||||
pool, err = dockertest.NewPool("")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
container, err = pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "nats",
|
||||
Tag: "2.10.9-alpine",
|
||||
Cmd: []string{"-DVV", "-js"},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start container: %s", err)
|
||||
}
|
||||
|
||||
handleInterrupt(pool, container)
|
||||
|
||||
natsURL = fmt.Sprintf("nats://%s:%s", "localhost", container.GetPort("4222/tcp"))
|
||||
|
||||
if err := pool.Retry(func() error {
|
||||
_, err = nats.NewPublisher(context.Background(), natsURL, stream)
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
if err := pool.Retry(func() error {
|
||||
_, err = nats.NewSubscriber(context.Background(), natsURL, logger)
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-c
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
broker "github.com/absmach/magistrala/pkg/messaging/nats"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
)
|
||||
|
||||
const maxReconnects = -1
|
||||
|
||||
var _ events.Subscriber = (*subEventStore)(nil)
|
||||
|
||||
var (
|
||||
eventsPrefix = "events"
|
||||
|
||||
jsStreamConfig = jetstream.StreamConfig{
|
||||
Name: "events",
|
||||
Description: "Magistrala stream for sending and receiving messages in between Magistrala events",
|
||||
Subjects: []string{"events.>"},
|
||||
Retention: jetstream.LimitsPolicy,
|
||||
MaxMsgsPerSubject: 1e9,
|
||||
MaxAge: time.Hour * 24,
|
||||
MaxMsgSize: 1024 * 1024,
|
||||
Discard: jetstream.DiscardOld,
|
||||
Storage: jetstream.FileStorage,
|
||||
}
|
||||
|
||||
// ErrEmptyStream is returned when stream name is empty.
|
||||
ErrEmptyStream = errors.New("stream name cannot be empty")
|
||||
|
||||
// ErrEmptyConsumer is returned when consumer name is empty.
|
||||
ErrEmptyConsumer = errors.New("consumer name cannot be empty")
|
||||
)
|
||||
|
||||
type subEventStore struct {
|
||||
conn *nats.Conn
|
||||
pubsub messaging.PubSub
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewSubscriber(ctx context.Context, url string, logger *slog.Logger) (events.Subscriber, error) {
|
||||
conn, err := nats.Connect(url, nats.MaxReconnects(maxReconnects))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
js, err := jetstream.New(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsStream, err := js.CreateStream(ctx, jsStreamConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubsub, err := broker.NewPubSub(ctx, url, logger, broker.Stream(jsStream))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &subEventStore{
|
||||
conn: conn,
|
||||
pubsub: pubsub,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error {
|
||||
if cfg.Stream == "" {
|
||||
return ErrEmptyStream
|
||||
}
|
||||
if cfg.Consumer == "" {
|
||||
return ErrEmptyConsumer
|
||||
}
|
||||
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: cfg.Consumer,
|
||||
Topic: cfg.Stream,
|
||||
Handler: &eventHandler{
|
||||
handler: cfg.Handler,
|
||||
ctx: ctx,
|
||||
logger: es.logger,
|
||||
},
|
||||
DeliveryPolicy: messaging.DeliverNewPolicy,
|
||||
}
|
||||
|
||||
return es.pubsub.Subscribe(ctx, subCfg)
|
||||
}
|
||||
|
||||
func (es *subEventStore) Close() error {
|
||||
es.conn.Close()
|
||||
return es.pubsub.Close()
|
||||
}
|
||||
|
||||
type event struct {
|
||||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
func (re event) Encode() (map[string]interface{}, error) {
|
||||
return re.Data, nil
|
||||
}
|
||||
|
||||
type eventHandler struct {
|
||||
handler events.EventHandler
|
||||
ctx context.Context
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (eh *eventHandler) Handle(msg *messaging.Message) error {
|
||||
event := event{
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(msg.GetPayload(), &event.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := eh.handler.Handle(eh.ctx, event); err != nil {
|
||||
eh.logger.Warn(fmt.Sprintf("failed to handle nats event: %s", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (eh *eventHandler) Cancel() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package redis contains the domain concept definitions needed to support
|
||||
// Magistrala redis events source service functionality.
|
||||
//
|
||||
// It provides the abstraction of the redis stream and its operations.
|
||||
package rabbitmq
|
||||
@@ -1,73 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package rabbitmq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
broker "github.com/absmach/magistrala/pkg/messaging/rabbitmq"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
type pubEventStore struct {
|
||||
conn *amqp.Connection
|
||||
publisher messaging.Publisher
|
||||
stream string
|
||||
}
|
||||
|
||||
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
|
||||
conn, err := amqp.Dial(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch, err := conn.Channel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publisher, err := broker.NewPublisher(url, broker.Prefix(eventsPrefix), broker.Exchange(exchangeName), broker.Channel(ch))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
es := &pubEventStore{
|
||||
conn: conn,
|
||||
publisher: publisher,
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
return es, nil
|
||||
}
|
||||
|
||||
func (es *pubEventStore) Publish(ctx context.Context, event events.Event) error {
|
||||
values, err := event.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values["occurred_at"] = time.Now().UnixNano()
|
||||
|
||||
data, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record := &messaging.Message{
|
||||
Payload: data,
|
||||
}
|
||||
|
||||
return es.publisher.Publish(ctx, es.stream, record)
|
||||
}
|
||||
|
||||
func (es *pubEventStore) Close() error {
|
||||
es.conn.Close()
|
||||
|
||||
return es.publisher.Close()
|
||||
}
|
||||
@@ -1,326 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package rabbitmq_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/events/rabbitmq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
eventsChan = make(chan map[string]interface{})
|
||||
logger = mglog.NewMock()
|
||||
errFailed = errors.New("failed")
|
||||
numEvents = 100
|
||||
)
|
||||
|
||||
type testEvent struct {
|
||||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
func (te testEvent) Encode() (map[string]interface{}, error) {
|
||||
data := make(map[string]interface{})
|
||||
for k, v := range te.Data {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
data[k] = v
|
||||
case float64:
|
||||
data[k] = v
|
||||
default:
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data[k] = string(b)
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
_, err := rabbitmq.NewPublisher(context.Background(), "http://invaliurl.com", stream)
|
||||
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
|
||||
|
||||
publisher, err := rabbitmq.NewPublisher(context.Background(), rabbitmqURL, stream)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
defer publisher.Close()
|
||||
|
||||
_, err = rabbitmq.NewSubscriber("http://invaliurl.com", logger)
|
||||
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
|
||||
|
||||
subcriber, err := rabbitmq.NewSubscriber(rabbitmqURL, logger)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
defer subcriber.Close()
|
||||
|
||||
cfg := events.SubscriberConfig{
|
||||
Stream: "events." + stream,
|
||||
Consumer: consumer,
|
||||
Handler: handler{},
|
||||
}
|
||||
err = subcriber.Subscribe(context.Background(), cfg)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
event map[string]interface{}
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish event successfully",
|
||||
err: nil,
|
||||
event: map[string]interface{}{
|
||||
"temperature": fmt.Sprintf("%f", rand.Float64()),
|
||||
"humidity": fmt.Sprintf("%f", rand.Float64()),
|
||||
"sensor_id": "abc123",
|
||||
"location": "Earth",
|
||||
"status": "normal",
|
||||
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
"operation": "create",
|
||||
"occurred_at": time.Now().UnixNano(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "publish with nil event",
|
||||
err: nil,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish event with invalid event location",
|
||||
err: fmt.Errorf("json: unsupported type: chan int"),
|
||||
event: map[string]interface{}{
|
||||
"temperature": fmt.Sprintf("%f", rand.Float64()),
|
||||
"humidity": fmt.Sprintf("%f", rand.Float64()),
|
||||
"sensor_id": "abc123",
|
||||
"location": make(chan int),
|
||||
"status": "normal",
|
||||
"timestamp": "invalid",
|
||||
"operation": "create",
|
||||
"occurred_at": time.Now().UnixNano(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "publish event with nested sting value",
|
||||
err: nil,
|
||||
event: map[string]interface{}{
|
||||
"temperature": fmt.Sprintf("%f", rand.Float64()),
|
||||
"humidity": fmt.Sprintf("%f", rand.Float64()),
|
||||
"sensor_id": "abc123",
|
||||
"location": map[string]string{
|
||||
"lat": fmt.Sprintf("%f", rand.Float64()),
|
||||
"lng": fmt.Sprintf("%f", rand.Float64()),
|
||||
},
|
||||
"status": "normal",
|
||||
"timestamp": "invalid",
|
||||
"operation": "create",
|
||||
"occurred_at": time.Now().UnixNano(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
event := testEvent{Data: tc.event}
|
||||
|
||||
err := publisher.Publish(context.Background(), event)
|
||||
switch tc.err {
|
||||
case nil:
|
||||
receivedEvent := <-eventsChan
|
||||
|
||||
val := int64(receivedEvent["occurred_at"].(float64))
|
||||
if assert.WithinRange(t, time.Unix(0, val), time.Now().Add(-time.Second), time.Now().Add(time.Second)) {
|
||||
delete(receivedEvent, "occurred_at")
|
||||
delete(tc.event, "occurred_at")
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.event["temperature"], receivedEvent["temperature"])
|
||||
assert.Equal(t, tc.event["humidity"], receivedEvent["humidity"])
|
||||
assert.Equal(t, tc.event["sensor_id"], receivedEvent["sensor_id"])
|
||||
assert.Equal(t, tc.event["status"], receivedEvent["status"])
|
||||
assert.Equal(t, tc.event["timestamp"], receivedEvent["timestamp"])
|
||||
assert.Equal(t, tc.event["operation"], receivedEvent["operation"])
|
||||
|
||||
default:
|
||||
assert.ErrorContains(t, err, tc.err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubsub(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
stream string
|
||||
consumer string
|
||||
err error
|
||||
handler events.EventHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a stream",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same stream",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty stream with an empty consumer",
|
||||
stream: "",
|
||||
consumer: "",
|
||||
err: rabbitmq.ErrEmptyStream,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty stream with a valid consumer",
|
||||
stream: "",
|
||||
consumer: consumer,
|
||||
err: rabbitmq.ErrEmptyStream,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a valid stream with an empty consumer",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: "",
|
||||
err: rabbitmq.ErrEmptyConsumer,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to another stream",
|
||||
stream: fmt.Sprintf("events.%s.%d", stream, 1),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a stream with malformed handler",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
subcriber, err := rabbitmq.NewSubscriber(rabbitmqURL, logger)
|
||||
if err != nil {
|
||||
assert.Equal(t, err, tc.err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
cfg := events.SubscriberConfig{
|
||||
Stream: tc.stream,
|
||||
Consumer: tc.consumer,
|
||||
Handler: tc.handler,
|
||||
}
|
||||
switch err := subcriber.Subscribe(context.Background(), cfg); {
|
||||
case err == nil:
|
||||
assert.Nil(t, err)
|
||||
default:
|
||||
assert.Equal(t, err, tc.err)
|
||||
}
|
||||
|
||||
err = subcriber.Close()
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnavailablePublish(t *testing.T) {
|
||||
publisher, err := rabbitmq.NewPublisher(context.Background(), rabbitmqURL, stream)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
|
||||
subcriber, err := rabbitmq.NewSubscriber(rabbitmqURL, logger)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
|
||||
cfg := events.SubscriberConfig{
|
||||
Stream: "events." + stream,
|
||||
Consumer: consumer,
|
||||
Handler: handler{},
|
||||
}
|
||||
err = subcriber.Subscribe(context.Background(), cfg)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
|
||||
|
||||
err = pool.Client.PauseContainer(container.Container.ID)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on pausing container: %s", err))
|
||||
|
||||
spawnGoroutines(publisher, t)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
err = pool.Client.UnpauseContainer(container.Container.ID)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on unpausing container: %s", err))
|
||||
|
||||
// Wait for the events to be published.
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
err = publisher.Close()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on closing publisher: %s", err))
|
||||
|
||||
// read all the events from the channel and assert that they are 10.
|
||||
var receivedEvents []map[string]interface{}
|
||||
for i := 0; i < numEvents; i++ {
|
||||
event := <-eventsChan
|
||||
receivedEvents = append(receivedEvents, event)
|
||||
}
|
||||
assert.Len(t, receivedEvents, numEvents, "got unexpected number of events")
|
||||
}
|
||||
|
||||
func generateRandomEvent() testEvent {
|
||||
return testEvent{
|
||||
Data: map[string]interface{}{
|
||||
"temperature": fmt.Sprintf("%f", rand.Float64()),
|
||||
"humidity": fmt.Sprintf("%f", rand.Float64()),
|
||||
"sensor_id": fmt.Sprintf("%d", rand.Intn(1000)),
|
||||
"location": fmt.Sprintf("%f", rand.Float64()),
|
||||
"status": fmt.Sprintf("%d", rand.Intn(1000)),
|
||||
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
"operation": "create",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func spawnGoroutines(publisher events.Publisher, t *testing.T) {
|
||||
for i := 0; i < numEvents; i++ {
|
||||
go func() {
|
||||
err := publisher.Publish(context.Background(), generateRandomEvent())
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
fail bool
|
||||
}
|
||||
|
||||
func (h handler) Handle(_ context.Context, event events.Event) error {
|
||||
if h.fail {
|
||||
return errFailed
|
||||
}
|
||||
data, err := event.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventsChan <- data
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package rabbitmq_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events/rabbitmq"
|
||||
"github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
rabbitmqURL string
|
||||
stream = "tests.events"
|
||||
consumer = "tests-consumer"
|
||||
pool *dockertest.Pool
|
||||
container *dockertest.Resource
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
var err error
|
||||
pool, err = dockertest.NewPool("")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
container, err = pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "rabbitmq",
|
||||
Tag: "3.12.12",
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start container: %s", err)
|
||||
}
|
||||
|
||||
handleInterrupt(pool, container)
|
||||
|
||||
rabbitmqURL = fmt.Sprintf("amqp://%s:%s", "localhost", container.GetPort("5672/tcp"))
|
||||
|
||||
if err := pool.Retry(func() error {
|
||||
_, err = rabbitmq.NewPublisher(context.Background(), rabbitmqURL, stream)
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
if err := pool.Retry(func() error {
|
||||
_, err = rabbitmq.NewSubscriber(rabbitmqURL, logger)
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
|
||||
c := make(chan os.Signal, 2)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-c
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
@@ -1,122 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package rabbitmq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
broker "github.com/absmach/magistrala/pkg/messaging/rabbitmq"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
var _ events.Subscriber = (*subEventStore)(nil)
|
||||
|
||||
var (
|
||||
exchangeName = "events"
|
||||
eventsPrefix = "events"
|
||||
|
||||
// ErrEmptyStream is returned when stream name is empty.
|
||||
ErrEmptyStream = errors.New("stream name cannot be empty")
|
||||
|
||||
// ErrEmptyConsumer is returned when consumer name is empty.
|
||||
ErrEmptyConsumer = errors.New("consumer name cannot be empty")
|
||||
)
|
||||
|
||||
type subEventStore struct {
|
||||
conn *amqp.Connection
|
||||
pubsub messaging.PubSub
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewSubscriber(url string, logger *slog.Logger) (events.Subscriber, error) {
|
||||
conn, err := amqp.Dial(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch, err := conn.Channel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubsub, err := broker.NewPubSub(url, logger, broker.Channel(ch), broker.Exchange(exchangeName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &subEventStore{
|
||||
conn: conn,
|
||||
pubsub: pubsub,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error {
|
||||
if cfg.Stream == "" {
|
||||
return ErrEmptyStream
|
||||
}
|
||||
if cfg.Consumer == "" {
|
||||
return ErrEmptyConsumer
|
||||
}
|
||||
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: cfg.Consumer,
|
||||
Topic: cfg.Stream,
|
||||
Handler: &eventHandler{
|
||||
handler: cfg.Handler,
|
||||
ctx: ctx,
|
||||
logger: es.logger,
|
||||
},
|
||||
DeliveryPolicy: messaging.DeliverNewPolicy,
|
||||
}
|
||||
|
||||
return es.pubsub.Subscribe(ctx, subCfg)
|
||||
}
|
||||
|
||||
func (es *subEventStore) Close() error {
|
||||
es.conn.Close()
|
||||
return es.pubsub.Close()
|
||||
}
|
||||
|
||||
type event struct {
|
||||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
func (re event) Encode() (map[string]interface{}, error) {
|
||||
return re.Data, nil
|
||||
}
|
||||
|
||||
type eventHandler struct {
|
||||
handler events.EventHandler
|
||||
ctx context.Context
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (eh *eventHandler) Handle(msg *messaging.Message) error {
|
||||
event := event{
|
||||
Data: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(msg.GetPayload(), &event.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := eh.handler.Handle(eh.ctx, event); err != nil {
|
||||
eh.logger.Warn(fmt.Sprintf("failed to handle rabbitmq event: %s", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (eh *eventHandler) Cancel() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package redis contains the domain concept definitions needed to support
|
||||
// Magistrala redis events source service functionality.
|
||||
//
|
||||
// It provides the abstraction of the redis stream and its operations.
|
||||
package redis
|
||||
@@ -1,118 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type pubEventStore struct {
|
||||
client *redis.Client
|
||||
unpublishedEvents chan *redis.XAddArgs
|
||||
stream string
|
||||
mu sync.Mutex
|
||||
flushPeriod time.Duration
|
||||
}
|
||||
|
||||
func NewPublisher(ctx context.Context, url, stream string, flushPeriod time.Duration) (events.Publisher, error) {
|
||||
opts, err := redis.ParseURL(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
es := &pubEventStore{
|
||||
client: redis.NewClient(opts),
|
||||
unpublishedEvents: make(chan *redis.XAddArgs, events.MaxUnpublishedEvents),
|
||||
stream: eventsPrefix + stream,
|
||||
flushPeriod: flushPeriod,
|
||||
}
|
||||
|
||||
go es.flushUnpublished(ctx)
|
||||
|
||||
return es, nil
|
||||
}
|
||||
|
||||
func (es *pubEventStore) Publish(ctx context.Context, event events.Event) error {
|
||||
values, err := event.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values["occurred_at"] = time.Now().UnixNano()
|
||||
|
||||
data, err := json.Marshal(values)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
record := &redis.XAddArgs{
|
||||
Stream: es.stream,
|
||||
MaxLen: events.MaxEventStreamLen,
|
||||
Approx: true,
|
||||
Values: map[string]interface{}{"data": string(data)},
|
||||
}
|
||||
|
||||
switch err := es.checkConnection(ctx); err {
|
||||
case nil:
|
||||
return es.client.XAdd(ctx, record).Err()
|
||||
default:
|
||||
es.mu.Lock()
|
||||
defer es.mu.Unlock()
|
||||
|
||||
// If the channel is full (rarely happens), drop the events.
|
||||
if len(es.unpublishedEvents) == int(events.MaxUnpublishedEvents) {
|
||||
return nil
|
||||
}
|
||||
|
||||
es.unpublishedEvents <- record
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// flushUnpublished periodically checks the Redis connection and publishes
|
||||
// the events that were not published due to a connection error.
|
||||
func (es *pubEventStore) flushUnpublished(ctx context.Context) {
|
||||
defer close(es.unpublishedEvents)
|
||||
|
||||
ticker := time.NewTicker(es.flushPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := es.checkConnection(ctx); err == nil {
|
||||
es.mu.Lock()
|
||||
for i := len(es.unpublishedEvents) - 1; i >= 0; i-- {
|
||||
record := <-es.unpublishedEvents
|
||||
if err := es.client.XAdd(ctx, record).Err(); err != nil {
|
||||
es.unpublishedEvents <- record
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
es.mu.Unlock()
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (es *pubEventStore) Close() error {
|
||||
return es.client.Close()
|
||||
}
|
||||
|
||||
func (es *pubEventStore) checkConnection(ctx context.Context) error {
|
||||
// A timeout is used to avoid blocking the main thread
|
||||
ctx, cancel := context.WithTimeout(ctx, events.ConnCheckInterval)
|
||||
defer cancel()
|
||||
|
||||
return es.client.Ping(ctx).Err()
|
||||
}
|
||||
@@ -1,321 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/events/redis"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
stream = "tests.events"
|
||||
consumer = "test-consumer"
|
||||
eventsChan = make(chan map[string]interface{})
|
||||
logger = mglog.NewMock()
|
||||
errFailed = errors.New("failed")
|
||||
numEvents = 100
|
||||
)
|
||||
|
||||
type testEvent struct {
|
||||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
func (te testEvent) Encode() (map[string]interface{}, error) {
|
||||
if te.Data == nil {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
|
||||
return te.Data, nil
|
||||
}
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on flushing redis: %s", err))
|
||||
|
||||
_, err = redis.NewPublisher(context.Background(), "http://invaliurl.com", stream, events.UnpublishedEventsCheckInterval)
|
||||
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
|
||||
|
||||
publisher, err := redis.NewPublisher(context.Background(), redisURL, stream, events.UnpublishedEventsCheckInterval)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
defer publisher.Close()
|
||||
|
||||
_, err = redis.NewSubscriber("http://invaliurl.com", logger)
|
||||
assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err)
|
||||
|
||||
subcriber, err := redis.NewSubscriber(redisURL, logger)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
defer subcriber.Close()
|
||||
|
||||
cfg := events.SubscriberConfig{
|
||||
Stream: "events." + stream,
|
||||
Consumer: consumer,
|
||||
Handler: handler{},
|
||||
}
|
||||
err = subcriber.Subscribe(context.Background(), cfg)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
event map[string]interface{}
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish event successfully",
|
||||
err: nil,
|
||||
event: map[string]interface{}{
|
||||
"temperature": float64(rand.Float64()),
|
||||
"humidity": float64(rand.Float64()),
|
||||
"sensor_id": "abc123",
|
||||
"location": "Earth",
|
||||
"status": "normal",
|
||||
"timestamp": float64(time.Now().UnixNano()),
|
||||
"operation": "create",
|
||||
"occurred_at": time.Now().UnixNano(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "publish with nil event",
|
||||
err: nil,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish event with invalid event location",
|
||||
err: fmt.Errorf("json: unsupported type: chan int"),
|
||||
event: map[string]interface{}{
|
||||
"temperature": float64(rand.Float64()),
|
||||
"humidity": float64(rand.Float64()),
|
||||
"sensor_id": "abc123",
|
||||
"location": make(chan int),
|
||||
"status": "normal",
|
||||
"timestamp": "invalid",
|
||||
"operation": "create",
|
||||
"occurred_at": float64(time.Now().UnixNano()),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "publish event with nested sting value",
|
||||
err: nil,
|
||||
event: map[string]interface{}{
|
||||
"temperature": float64(rand.Float64()),
|
||||
"humidity": float64(rand.Float64()),
|
||||
"sensor_id": "abc123",
|
||||
"location": map[string]string{
|
||||
"lat": fmt.Sprintf("%f", rand.Float64()),
|
||||
"lng": fmt.Sprintf("%f", rand.Float64()),
|
||||
},
|
||||
"status": "normal",
|
||||
"timestamp": "invalid",
|
||||
"operation": "create",
|
||||
"occurred_at": float64(time.Now().UnixNano()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
event := testEvent{Data: tc.event}
|
||||
|
||||
err := publisher.Publish(context.Background(), event)
|
||||
switch tc.err {
|
||||
case nil:
|
||||
receivedEvent := <-eventsChan
|
||||
|
||||
roa := receivedEvent["occurred_at"].(float64)
|
||||
assert.Nil(t, err)
|
||||
if assert.WithinRange(t, time.Unix(0, int64(roa)), time.Now().Add(-time.Second), time.Now().Add(time.Second)) {
|
||||
delete(receivedEvent, "occurred_at")
|
||||
delete(tc.event, "occurred_at")
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.event["temperature"], receivedEvent["temperature"])
|
||||
assert.Equal(t, tc.event["humidity"], receivedEvent["humidity"])
|
||||
assert.Equal(t, tc.event["sensor_id"], receivedEvent["sensor_id"])
|
||||
assert.Equal(t, tc.event["status"], receivedEvent["status"])
|
||||
assert.Equal(t, tc.event["timestamp"], receivedEvent["timestamp"])
|
||||
assert.Equal(t, tc.event["operation"], receivedEvent["operation"])
|
||||
|
||||
default:
|
||||
assert.ErrorContains(t, err, tc.err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubsub(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on flushing redis: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
stream string
|
||||
consumer string
|
||||
err error
|
||||
handler events.EventHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a stream",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same stream",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty stream with an empty consumer",
|
||||
stream: "",
|
||||
consumer: "",
|
||||
err: redis.ErrEmptyStream,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty stream with a valid consumer",
|
||||
stream: "",
|
||||
consumer: consumer,
|
||||
err: redis.ErrEmptyStream,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a valid stream with an empty consumer",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: "",
|
||||
err: redis.ErrEmptyConsumer,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to another stream",
|
||||
stream: fmt.Sprintf("events.%s.%d", stream, 1),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{false},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a stream with malformed handler",
|
||||
stream: fmt.Sprintf("events.%s", stream),
|
||||
consumer: consumer,
|
||||
err: nil,
|
||||
handler: handler{true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
subcriber, err := redis.NewSubscriber(redisURL, logger)
|
||||
if err != nil {
|
||||
assert.Equal(t, err, tc.err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
cfg := events.SubscriberConfig{
|
||||
Stream: tc.stream,
|
||||
Consumer: tc.consumer,
|
||||
Handler: tc.handler,
|
||||
}
|
||||
switch err := subcriber.Subscribe(context.Background(), cfg); {
|
||||
case err == nil:
|
||||
assert.Nil(t, err)
|
||||
default:
|
||||
assert.Equal(t, err, tc.err)
|
||||
}
|
||||
|
||||
err = subcriber.Close()
|
||||
assert.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnavailablePublish(t *testing.T) {
|
||||
publisher, err := redis.NewPublisher(context.Background(), redisURL, stream, time.Second)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
|
||||
subcriber, err := redis.NewSubscriber(redisURL, logger)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err))
|
||||
|
||||
cfg := events.SubscriberConfig{
|
||||
Stream: "events." + stream,
|
||||
Consumer: consumer,
|
||||
Handler: handler{},
|
||||
}
|
||||
err = subcriber.Subscribe(context.Background(), cfg)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err))
|
||||
|
||||
err = pool.Client.PauseContainer(container.Container.ID)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on pausing container: %s", err))
|
||||
|
||||
spawnGoroutines(publisher, t)
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
err = pool.Client.UnpauseContainer(container.Container.ID)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on unpausing container: %s", err))
|
||||
|
||||
// Wait for the events to be published.
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
err = publisher.Close()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on closing publisher: %s", err))
|
||||
|
||||
var receivedEvents []map[string]interface{}
|
||||
for i := 0; i < numEvents; i++ {
|
||||
event := <-eventsChan
|
||||
receivedEvents = append(receivedEvents, event)
|
||||
}
|
||||
assert.Len(t, receivedEvents, numEvents, "got unexpected number of events")
|
||||
}
|
||||
|
||||
func generateRandomEvent() testEvent {
|
||||
return testEvent{
|
||||
Data: map[string]interface{}{
|
||||
"temperature": fmt.Sprintf("%f", rand.Float64()),
|
||||
"humidity": fmt.Sprintf("%f", rand.Float64()),
|
||||
"sensor_id": fmt.Sprintf("%d", rand.Intn(1000)),
|
||||
"location": fmt.Sprintf("%f", rand.Float64()),
|
||||
"status": fmt.Sprintf("%d", rand.Intn(1000)),
|
||||
"timestamp": fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
"operation": "create",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func spawnGoroutines(publisher events.Publisher, t *testing.T) {
|
||||
for i := 0; i < numEvents; i++ {
|
||||
go func() {
|
||||
err := publisher.Publish(context.Background(), generateRandomEvent())
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
fail bool
|
||||
}
|
||||
|
||||
func (h handler) Handle(_ context.Context, event events.Event) error {
|
||||
if h.fail {
|
||||
return errFailed
|
||||
}
|
||||
data, err := event.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventsChan <- data
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package redis_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
redisClient *redis.Client
|
||||
redisURL string
|
||||
pool *dockertest.Pool
|
||||
container *dockertest.Resource
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
var err error
|
||||
pool, err = dockertest.NewPool("")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
container, err = pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "redis",
|
||||
Tag: "7.2.4-alpine",
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start container: %s", err)
|
||||
}
|
||||
|
||||
handleInterrupt(pool, container)
|
||||
|
||||
redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp"))
|
||||
ropts, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not parse redis URL: %s", err)
|
||||
}
|
||||
|
||||
if err := pool.Retry(func() error {
|
||||
redisClient = redis.NewClient(ropts)
|
||||
|
||||
return redisClient.Ping(context.Background()).Err()
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-c
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
eventsPrefix = "events."
|
||||
eventCount = 100
|
||||
exists = "BUSYGROUP Consumer Group name already exists"
|
||||
group = "magistrala"
|
||||
)
|
||||
|
||||
var _ events.Subscriber = (*subEventStore)(nil)
|
||||
|
||||
var (
|
||||
// ErrEmptyStream is returned when stream name is empty.
|
||||
ErrEmptyStream = errors.New("stream name cannot be empty")
|
||||
|
||||
// ErrEmptyConsumer is returned when consumer name is empty.
|
||||
ErrEmptyConsumer = errors.New("consumer name cannot be empty")
|
||||
)
|
||||
|
||||
type subEventStore struct {
|
||||
client *redis.Client
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewSubscriber(url string, logger *slog.Logger) (events.Subscriber, error) {
|
||||
opts, err := redis.ParseURL(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &subEventStore{
|
||||
client: redis.NewClient(opts),
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error {
|
||||
if cfg.Stream == "" {
|
||||
return ErrEmptyStream
|
||||
}
|
||||
if cfg.Consumer == "" {
|
||||
return ErrEmptyConsumer
|
||||
}
|
||||
|
||||
err := es.client.XGroupCreateMkStream(ctx, cfg.Stream, group, "$").Err()
|
||||
if err != nil && err.Error() != exists {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
msgs, err := es.client.XReadGroup(ctx, &redis.XReadGroupArgs{
|
||||
Group: group,
|
||||
Consumer: cfg.Consumer,
|
||||
Streams: []string{cfg.Stream, ">"},
|
||||
Count: eventCount,
|
||||
}).Result()
|
||||
if err != nil {
|
||||
es.logger.Warn(fmt.Sprintf("failed to read from redis stream: %s", err))
|
||||
|
||||
continue
|
||||
}
|
||||
if len(msgs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
es.handle(ctx, cfg.Stream, msgs[0].Messages, cfg.Handler)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *subEventStore) Close() error {
|
||||
return es.client.Close()
|
||||
}
|
||||
|
||||
type redisEvent struct {
|
||||
Data map[string]interface{}
|
||||
}
|
||||
|
||||
func (re redisEvent) Encode() (map[string]interface{}, error) {
|
||||
return re.Data, nil
|
||||
}
|
||||
|
||||
func (es *subEventStore) handle(ctx context.Context, stream string, msgs []redis.XMessage, h events.EventHandler) {
|
||||
for _, msg := range msgs {
|
||||
var data map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(msg.Values["data"].(string)), &data); err != nil {
|
||||
es.logger.Warn(fmt.Sprintf("failed to unmarshal redis event: %s", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
event := redisEvent{
|
||||
Data: data,
|
||||
}
|
||||
|
||||
if err := h.Handle(ctx, event); err != nil {
|
||||
es.logger.Warn(fmt.Sprintf("failed to handle redis event: %s", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := es.client.XAck(ctx, stream, group, msg.ID).Err(); err != nil {
|
||||
es.logger.Warn(fmt.Sprintf("failed to ack redis event: %s", err))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build nats
|
||||
// +build nats
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/events/nats"
|
||||
)
|
||||
|
||||
// StreamAllEvents represents subject to subscribe for all the events.
|
||||
const StreamAllEvents = "events.>"
|
||||
|
||||
func init() {
|
||||
log.Println("The binary was build using nats as the events store")
|
||||
}
|
||||
|
||||
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
|
||||
pb, err := nats.NewPublisher(ctx, url, stream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
|
||||
func NewSubscriber(ctx context.Context, url string, logger *slog.Logger) (events.Subscriber, error) {
|
||||
pb, err := nats.NewSubscriber(ctx, url, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build rabbitmq
|
||||
// +build rabbitmq
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/events/rabbitmq"
|
||||
)
|
||||
|
||||
// StreamAllEvents represents subject to subscribe for all the events.
|
||||
const StreamAllEvents = "events.#"
|
||||
|
||||
func init() {
|
||||
log.Println("The binary was build using rabbitmq as the events store")
|
||||
}
|
||||
|
||||
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
|
||||
pb, err := rabbitmq.NewPublisher(ctx, url, stream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
|
||||
func NewSubscriber(_ context.Context, url string, logger *slog.Logger) (events.Subscriber, error) {
|
||||
pb, err := rabbitmq.NewSubscriber(url, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !nats && !rabbitmq
|
||||
// +build !nats,!rabbitmq
|
||||
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/events/redis"
|
||||
)
|
||||
|
||||
// StreamAllEvents represents subject to subscribe for all the events.
|
||||
const StreamAllEvents = ">"
|
||||
|
||||
func init() {
|
||||
log.Println("The binary was build using redis as the events store")
|
||||
}
|
||||
|
||||
func NewPublisher(ctx context.Context, url, stream string) (events.Publisher, error) {
|
||||
pb, err := redis.NewPublisher(ctx, url, stream, events.UnpublishedEventsCheckInterval)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
|
||||
func NewSubscriber(_ context.Context, url string, logger *slog.Logger) (events.Subscriber, error) {
|
||||
pb, err := redis.NewSubscriber(url, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package groups contains the domain concept definitions needed to support
|
||||
// Magistrala groups functionality.
|
||||
package groups
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package groups
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrInvalidStatus indicates invalid status.
|
||||
ErrInvalidStatus = errors.New("invalid groups status")
|
||||
|
||||
// ErrEnableGroup indicates error in enabling group.
|
||||
ErrEnableGroup = errors.New("failed to enable group")
|
||||
|
||||
// ErrDisableGroup indicates error in disabling group.
|
||||
ErrDisableGroup = errors.New("failed to disable group")
|
||||
)
|
||||
@@ -1,133 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package groups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
)
|
||||
|
||||
// MaxLevel represents the maximum group hierarchy level.
|
||||
const MaxLevel = uint64(5)
|
||||
|
||||
// Group represents the group of Clients.
|
||||
// Indicates a level in tree hierarchy. Root node is level 1.
|
||||
// Path in a tree consisting of group IDs
|
||||
// Paths are unique per domain.
|
||||
type Group struct {
|
||||
ID string `json:"id"`
|
||||
Domain string `json:"domain_id,omitempty"`
|
||||
Parent string `json:"parent_id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
Level int `json:"level,omitempty"`
|
||||
Path string `json:"path,omitempty"`
|
||||
Children []*Group `json:"children,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
UpdatedBy string `json:"updated_by,omitempty"`
|
||||
Status Status `json:"status"`
|
||||
Permissions []string `json:"permissions,omitempty"`
|
||||
}
|
||||
|
||||
type Member struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// Memberships contains page related metadata as well as list of memberships that
|
||||
// belong to this page.
|
||||
type MembersPage struct {
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Members []Member `json:"members"`
|
||||
}
|
||||
|
||||
// Page contains page related metadata as well as list
|
||||
// of Groups that belong to the page.
|
||||
type Page struct {
|
||||
PageMeta
|
||||
Path string
|
||||
Level uint64
|
||||
ParentID string
|
||||
Permission string
|
||||
ListPerms bool
|
||||
Direction int64 // ancestors (+1) or descendants (-1)
|
||||
Groups []Group
|
||||
}
|
||||
|
||||
// Metadata represents arbitrary JSON.
|
||||
type Metadata map[string]interface{}
|
||||
|
||||
// Repository specifies a group persistence API.
|
||||
//
|
||||
//go:generate mockery --name Repository --output=./mocks --filename repository.go --quiet --note "Copyright (c) Abstract Machines" --unroll-variadic=false
|
||||
type Repository interface {
|
||||
// Save group.
|
||||
Save(ctx context.Context, g Group) (Group, error)
|
||||
|
||||
// Update a group.
|
||||
Update(ctx context.Context, g Group) (Group, error)
|
||||
|
||||
// RetrieveByID retrieves group by its id.
|
||||
RetrieveByID(ctx context.Context, id string) (Group, error)
|
||||
|
||||
// RetrieveAll retrieves all groups.
|
||||
RetrieveAll(ctx context.Context, gm Page) (Page, error)
|
||||
|
||||
// RetrieveByIDs retrieves group by ids and query.
|
||||
RetrieveByIDs(ctx context.Context, gm Page, ids ...string) (Page, error)
|
||||
|
||||
// ChangeStatus changes groups status to active or inactive
|
||||
ChangeStatus(ctx context.Context, group Group) (Group, error)
|
||||
|
||||
// AssignParentGroup assigns parent group id to a given group id
|
||||
AssignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) error
|
||||
|
||||
// UnassignParentGroup unassign parent group id fr given group id
|
||||
UnassignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) error
|
||||
|
||||
// Delete a group
|
||||
Delete(ctx context.Context, groupID string) error
|
||||
}
|
||||
|
||||
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" --unroll-variadic=false
|
||||
type Service interface {
|
||||
// CreateGroup creates new group.
|
||||
CreateGroup(ctx context.Context, session authn.Session, kind string, g Group) (Group, error)
|
||||
|
||||
// UpdateGroup updates the group identified by the provided ID.
|
||||
UpdateGroup(ctx context.Context, session authn.Session, g Group) (Group, error)
|
||||
|
||||
// ViewGroup retrieves data about the group identified by ID.
|
||||
ViewGroup(ctx context.Context, session authn.Session, id string) (Group, error)
|
||||
|
||||
// ViewGroupPerms retrieves permissions on the group id for the given authorized token.
|
||||
ViewGroupPerms(ctx context.Context, session authn.Session, id string) ([]string, error)
|
||||
|
||||
// ListGroups retrieves a list of groups basesd on entity type and entity id.
|
||||
ListGroups(ctx context.Context, session authn.Session, memberKind, memberID string, gm Page) (Page, error)
|
||||
|
||||
// ListMembers retrieves everything that is assigned to a group identified by groupID.
|
||||
ListMembers(ctx context.Context, session authn.Session, groupID, permission, memberKind string) (MembersPage, error)
|
||||
|
||||
// EnableGroup logically enables the group identified with the provided ID.
|
||||
EnableGroup(ctx context.Context, session authn.Session, id string) (Group, error)
|
||||
|
||||
// DisableGroup logically disables the group identified with the provided ID.
|
||||
DisableGroup(ctx context.Context, session authn.Session, id string) (Group, error)
|
||||
|
||||
// DeleteGroup delete the given group id
|
||||
DeleteGroup(ctx context.Context, session authn.Session, id string) error
|
||||
|
||||
// Assign member to group
|
||||
Assign(ctx context.Context, session authn.Session, groupID, relation, memberKind string, memberIDs ...string) (err error)
|
||||
|
||||
// Unassign member from group
|
||||
Unassign(ctx context.Context, session authn.Session, groupID, relation, memberKind string, memberIDs ...string) (err error)
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package mocks contains mocks for testing purposes.
|
||||
package mocks
|
||||
@@ -1,253 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
groups "github.com/absmach/magistrala/pkg/groups"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Repository is an autogenerated mock type for the Repository type
|
||||
type Repository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// AssignParentGroup provides a mock function with given fields: ctx, parentGroupID, groupIDs
|
||||
func (_m *Repository) AssignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) error {
|
||||
ret := _m.Called(ctx, parentGroupID, groupIDs)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AssignParentGroup")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, ...string) error); ok {
|
||||
r0 = rf(ctx, parentGroupID, groupIDs...)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// ChangeStatus provides a mock function with given fields: ctx, group
|
||||
func (_m *Repository) ChangeStatus(ctx context.Context, group groups.Group) (groups.Group, error) {
|
||||
ret := _m.Called(ctx, group)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ChangeStatus")
|
||||
}
|
||||
|
||||
var r0 groups.Group
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) (groups.Group, error)); ok {
|
||||
return rf(ctx, group)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) groups.Group); ok {
|
||||
r0 = rf(ctx, group)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Group)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, groups.Group) error); ok {
|
||||
r1 = rf(ctx, group)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Delete provides a mock function with given fields: ctx, groupID
|
||||
func (_m *Repository) Delete(ctx context.Context, groupID string) error {
|
||||
ret := _m.Called(ctx, groupID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Delete")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
|
||||
r0 = rf(ctx, groupID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// RetrieveAll provides a mock function with given fields: ctx, gm
|
||||
func (_m *Repository) RetrieveAll(ctx context.Context, gm groups.Page) (groups.Page, error) {
|
||||
ret := _m.Called(ctx, gm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveAll")
|
||||
}
|
||||
|
||||
var r0 groups.Page
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Page) (groups.Page, error)); ok {
|
||||
return rf(ctx, gm)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Page) groups.Page); ok {
|
||||
r0 = rf(ctx, gm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Page)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, groups.Page) error); ok {
|
||||
r1 = rf(ctx, gm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RetrieveByID provides a mock function with given fields: ctx, id
|
||||
func (_m *Repository) RetrieveByID(ctx context.Context, id string) (groups.Group, error) {
|
||||
ret := _m.Called(ctx, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveByID")
|
||||
}
|
||||
|
||||
var r0 groups.Group
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (groups.Group, error)); ok {
|
||||
return rf(ctx, id)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) groups.Group); ok {
|
||||
r0 = rf(ctx, id)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Group)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// RetrieveByIDs provides a mock function with given fields: ctx, gm, ids
|
||||
func (_m *Repository) RetrieveByIDs(ctx context.Context, gm groups.Page, ids ...string) (groups.Page, error) {
|
||||
ret := _m.Called(ctx, gm, ids)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveByIDs")
|
||||
}
|
||||
|
||||
var r0 groups.Page
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Page, ...string) (groups.Page, error)); ok {
|
||||
return rf(ctx, gm, ids...)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Page, ...string) groups.Page); ok {
|
||||
r0 = rf(ctx, gm, ids...)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Page)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, groups.Page, ...string) error); ok {
|
||||
r1 = rf(ctx, gm, ids...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Save provides a mock function with given fields: ctx, g
|
||||
func (_m *Repository) Save(ctx context.Context, g groups.Group) (groups.Group, error) {
|
||||
ret := _m.Called(ctx, g)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Save")
|
||||
}
|
||||
|
||||
var r0 groups.Group
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) (groups.Group, error)); ok {
|
||||
return rf(ctx, g)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) groups.Group); ok {
|
||||
r0 = rf(ctx, g)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Group)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, groups.Group) error); ok {
|
||||
r1 = rf(ctx, g)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// UnassignParentGroup provides a mock function with given fields: ctx, parentGroupID, groupIDs
|
||||
func (_m *Repository) UnassignParentGroup(ctx context.Context, parentGroupID string, groupIDs ...string) error {
|
||||
ret := _m.Called(ctx, parentGroupID, groupIDs)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UnassignParentGroup")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, ...string) error); ok {
|
||||
r0 = rf(ctx, parentGroupID, groupIDs...)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Update provides a mock function with given fields: ctx, g
|
||||
func (_m *Repository) Update(ctx context.Context, g groups.Group) (groups.Group, error) {
|
||||
ret := _m.Called(ctx, g)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Update")
|
||||
}
|
||||
|
||||
var r0 groups.Group
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) (groups.Group, error)); ok {
|
||||
return rf(ctx, g)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, groups.Group) groups.Group); ok {
|
||||
r0 = rf(ctx, g)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Group)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, groups.Group) error); ok {
|
||||
r1 = rf(ctx, g)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewRepository(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Repository {
|
||||
mock := &Repository{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,314 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
authn "github.com/absmach/magistrala/pkg/authn"
|
||||
|
||||
groups "github.com/absmach/magistrala/pkg/groups"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Assign provides a mock function with given fields: ctx, session, groupID, relation, memberKind, memberIDs
|
||||
func (_m *Service) Assign(ctx context.Context, session authn.Session, groupID string, relation string, memberKind string, memberIDs ...string) error {
|
||||
ret := _m.Called(ctx, session, groupID, relation, memberKind, memberIDs)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Assign")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string, ...string) error); ok {
|
||||
r0 = rf(ctx, session, groupID, relation, memberKind, memberIDs...)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// CreateGroup provides a mock function with given fields: ctx, session, kind, g
|
||||
func (_m *Service) CreateGroup(ctx context.Context, session authn.Session, kind string, g groups.Group) (groups.Group, error) {
|
||||
ret := _m.Called(ctx, session, kind, g)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CreateGroup")
|
||||
}
|
||||
|
||||
var r0 groups.Group
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, groups.Group) (groups.Group, error)); ok {
|
||||
return rf(ctx, session, kind, g)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, groups.Group) groups.Group); ok {
|
||||
r0 = rf(ctx, session, kind, g)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Group)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, groups.Group) error); ok {
|
||||
r1 = rf(ctx, session, kind, g)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// DeleteGroup provides a mock function with given fields: ctx, session, id
|
||||
func (_m *Service) DeleteGroup(ctx context.Context, session authn.Session, id string) error {
|
||||
ret := _m.Called(ctx, session, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for DeleteGroup")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok {
|
||||
r0 = rf(ctx, session, id)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// DisableGroup provides a mock function with given fields: ctx, session, id
|
||||
func (_m *Service) DisableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
|
||||
ret := _m.Called(ctx, session, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for DisableGroup")
|
||||
}
|
||||
|
||||
var r0 groups.Group
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) (groups.Group, error)); ok {
|
||||
return rf(ctx, session, id)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) groups.Group); ok {
|
||||
r0 = rf(ctx, session, id)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Group)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok {
|
||||
r1 = rf(ctx, session, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// EnableGroup provides a mock function with given fields: ctx, session, id
|
||||
func (_m *Service) EnableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
|
||||
ret := _m.Called(ctx, session, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for EnableGroup")
|
||||
}
|
||||
|
||||
var r0 groups.Group
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) (groups.Group, error)); ok {
|
||||
return rf(ctx, session, id)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) groups.Group); ok {
|
||||
r0 = rf(ctx, session, id)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Group)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok {
|
||||
r1 = rf(ctx, session, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListGroups provides a mock function with given fields: ctx, session, memberKind, memberID, gm
|
||||
func (_m *Service) ListGroups(ctx context.Context, session authn.Session, memberKind string, memberID string, gm groups.Page) (groups.Page, error) {
|
||||
ret := _m.Called(ctx, session, memberKind, memberID, gm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListGroups")
|
||||
}
|
||||
|
||||
var r0 groups.Page
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, groups.Page) (groups.Page, error)); ok {
|
||||
return rf(ctx, session, memberKind, memberID, gm)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, groups.Page) groups.Page); ok {
|
||||
r0 = rf(ctx, session, memberKind, memberID, gm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Page)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, groups.Page) error); ok {
|
||||
r1 = rf(ctx, session, memberKind, memberID, gm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListMembers provides a mock function with given fields: ctx, session, groupID, permission, memberKind
|
||||
func (_m *Service) ListMembers(ctx context.Context, session authn.Session, groupID string, permission string, memberKind string) (groups.MembersPage, error) {
|
||||
ret := _m.Called(ctx, session, groupID, permission, memberKind)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListMembers")
|
||||
}
|
||||
|
||||
var r0 groups.MembersPage
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string) (groups.MembersPage, error)); ok {
|
||||
return rf(ctx, session, groupID, permission, memberKind)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string) groups.MembersPage); ok {
|
||||
r0 = rf(ctx, session, groupID, permission, memberKind)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.MembersPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, string) error); ok {
|
||||
r1 = rf(ctx, session, groupID, permission, memberKind)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Unassign provides a mock function with given fields: ctx, session, groupID, relation, memberKind, memberIDs
|
||||
func (_m *Service) Unassign(ctx context.Context, session authn.Session, groupID string, relation string, memberKind string, memberIDs ...string) error {
|
||||
ret := _m.Called(ctx, session, groupID, relation, memberKind, memberIDs)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Unassign")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string, ...string) error); ok {
|
||||
r0 = rf(ctx, session, groupID, relation, memberKind, memberIDs...)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// UpdateGroup provides a mock function with given fields: ctx, session, g
|
||||
func (_m *Service) UpdateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) {
|
||||
ret := _m.Called(ctx, session, g)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateGroup")
|
||||
}
|
||||
|
||||
var r0 groups.Group
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, groups.Group) (groups.Group, error)); ok {
|
||||
return rf(ctx, session, g)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, groups.Group) groups.Group); ok {
|
||||
r0 = rf(ctx, session, g)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Group)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, groups.Group) error); ok {
|
||||
r1 = rf(ctx, session, g)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ViewGroup provides a mock function with given fields: ctx, session, id
|
||||
func (_m *Service) ViewGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
|
||||
ret := _m.Called(ctx, session, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ViewGroup")
|
||||
}
|
||||
|
||||
var r0 groups.Group
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) (groups.Group, error)); ok {
|
||||
return rf(ctx, session, id)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) groups.Group); ok {
|
||||
r0 = rf(ctx, session, id)
|
||||
} else {
|
||||
r0 = ret.Get(0).(groups.Group)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok {
|
||||
r1 = rf(ctx, session, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ViewGroupPerms provides a mock function with given fields: ctx, session, id
|
||||
func (_m *Service) ViewGroupPerms(ctx context.Context, session authn.Session, id string) ([]string, error) {
|
||||
ret := _m.Called(ctx, session, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ViewGroupPerms")
|
||||
}
|
||||
|
||||
var r0 []string
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) ([]string, error)); ok {
|
||||
return rf(ctx, session, id)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) []string); ok {
|
||||
r0 = rf(ctx, session, id)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]string)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok {
|
||||
r1 = rf(ctx, session, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package groups
|
||||
|
||||
// PageMeta contains page metadata that helps navigation.
|
||||
type PageMeta struct {
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
DomainID string `json:"domain_id,omitempty"`
|
||||
Tag string `json:"tag,omitempty"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
Status Status `json:"status,omitempty"`
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package groups
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
)
|
||||
|
||||
// Status represents User status.
|
||||
type Status uint8
|
||||
|
||||
// Possible User status values.
|
||||
const (
|
||||
// EnabledStatus represents enabled User.
|
||||
EnabledStatus Status = iota
|
||||
// DisabledStatus represents disabled User.
|
||||
DisabledStatus
|
||||
// DeletedStatus represents a user that will be deleted.
|
||||
DeletedStatus
|
||||
|
||||
// AllStatus is used for querying purposes to list users irrespective
|
||||
// of their status - both enabled and disabled. It is never stored in the
|
||||
// database as the actual User status and should always be the largest
|
||||
// value in this enumeration.
|
||||
AllStatus
|
||||
)
|
||||
|
||||
// String representation of the possible status values.
|
||||
const (
|
||||
Disabled = "disabled"
|
||||
Enabled = "enabled"
|
||||
Deleted = "deleted"
|
||||
All = "all"
|
||||
Unknown = "unknown"
|
||||
)
|
||||
|
||||
// String converts user/group status to string literal.
|
||||
func (s Status) String() string {
|
||||
switch s {
|
||||
case DisabledStatus:
|
||||
return Disabled
|
||||
case EnabledStatus:
|
||||
return Enabled
|
||||
case DeletedStatus:
|
||||
return Deleted
|
||||
case AllStatus:
|
||||
return All
|
||||
default:
|
||||
return Unknown
|
||||
}
|
||||
}
|
||||
|
||||
// ToStatus converts string value to a valid User/Group status.
|
||||
func ToStatus(status string) (Status, error) {
|
||||
switch status {
|
||||
case "", Enabled:
|
||||
return EnabledStatus, nil
|
||||
case Disabled:
|
||||
return DisabledStatus, nil
|
||||
case Deleted:
|
||||
return DeletedStatus, nil
|
||||
case All:
|
||||
return AllStatus, nil
|
||||
}
|
||||
return Status(0), svcerr.ErrInvalidStatus
|
||||
}
|
||||
|
||||
// Custom Marshaller for Uesr/Groups.
|
||||
func (s Status) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(s.String())
|
||||
}
|
||||
|
||||
// Custom Unmarshaler for User/Groups.
|
||||
func (s *Status) UnmarshalJSON(data []byte) error {
|
||||
str := strings.Trim(string(data), "\"")
|
||||
val, err := ToStatus(str)
|
||||
*s = val
|
||||
return err
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpcclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
domainsgrpc "github.com/absmach/magistrala/auth/api/grpc/domains"
|
||||
tokengrpc "github.com/absmach/magistrala/auth/api/grpc/token"
|
||||
thingsauth "github.com/absmach/magistrala/things/api/grpc"
|
||||
grpchealth "google.golang.org/grpc/health/grpc_health_v1"
|
||||
)
|
||||
|
||||
// SetupTokenClient loads auth services token gRPC configuration and creates new Token services gRPC client.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// tokenClient, tokenHandler, err := grpcclient.SetupTokenClient(ctx, grpcclient.Config{}).
|
||||
func SetupTokenClient(ctx context.Context, cfg Config) (magistrala.TokenServiceClient, Handler, error) {
|
||||
client, err := NewHandler(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "auth",
|
||||
})
|
||||
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
|
||||
return nil, nil, ErrSvcNotServing
|
||||
}
|
||||
|
||||
return tokengrpc.NewTokenClient(client.Connection(), cfg.Timeout), client, nil
|
||||
}
|
||||
|
||||
// SetupDomiansClient loads domains gRPC configuration and creates a new domains gRPC client.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// domainsClient, domainsHandler, err := grpcclient.SetupDomainsClient(ctx, grpcclient.Config{}).
|
||||
func SetupDomainsClient(ctx context.Context, cfg Config) (magistrala.DomainsServiceClient, Handler, error) {
|
||||
client, err := NewHandler(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "auth",
|
||||
})
|
||||
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
|
||||
return nil, nil, ErrSvcNotServing
|
||||
}
|
||||
|
||||
return domainsgrpc.NewDomainsClient(client.Connection(), cfg.Timeout), client, nil
|
||||
}
|
||||
|
||||
// SetupThingsClient loads things gRPC configuration and creates new things gRPC client.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// thingClient, thingHandler, err := grpcclient.SetupThings(ctx, grpcclient.Config{}).
|
||||
func SetupThingsClient(ctx context.Context, cfg Config) (magistrala.ThingsServiceClient, Handler, error) {
|
||||
client, err := NewHandler(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
health := grpchealth.NewHealthClient(client.Connection())
|
||||
resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{
|
||||
Service: "things",
|
||||
})
|
||||
if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING {
|
||||
return nil, nil, ErrSvcNotServing
|
||||
}
|
||||
|
||||
return thingsauth.NewClient(client.Connection(), cfg.Timeout), client, nil
|
||||
}
|
||||
@@ -1,179 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpcclient_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
domainsgrpcapi "github.com/absmach/magistrala/auth/api/grpc/domains"
|
||||
tokengrpcapi "github.com/absmach/magistrala/auth/api/grpc/token"
|
||||
"github.com/absmach/magistrala/auth/mocks"
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/absmach/magistrala/pkg/grpcclient"
|
||||
"github.com/absmach/magistrala/pkg/server"
|
||||
grpcserver "github.com/absmach/magistrala/pkg/server/grpc"
|
||||
thingsgrpcapi "github.com/absmach/magistrala/things/api/grpc"
|
||||
thmocks "github.com/absmach/magistrala/things/mocks"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestSetupToken(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
registerAuthServiceServer := func(srv *grpc.Server) {
|
||||
magistrala.RegisterTokenServiceServer(srv, tokengrpcapi.NewTokenServer(new(mocks.Service)))
|
||||
}
|
||||
gs := grpcserver.NewServer(ctx, cancel, "auth", server.Config{Port: "12345"}, registerAuthServiceServer, mglog.NewMock())
|
||||
go func() {
|
||||
err := gs.Start()
|
||||
assert.Nil(t, err, fmt.Sprintf(`"Unexpected error creating server %s"`, err))
|
||||
}()
|
||||
defer func() {
|
||||
err := gs.Stop()
|
||||
assert.Nil(t, err, fmt.Sprintf(`"Unexpected error stopping server %s"`, err))
|
||||
}()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
config grpcclient.Config
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "successful",
|
||||
config: grpcclient.Config{
|
||||
URL: "localhost:12345",
|
||||
Timeout: time.Second,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed with empty URL",
|
||||
config: grpcclient.Config{
|
||||
URL: "",
|
||||
Timeout: time.Second,
|
||||
},
|
||||
err: errors.New("service is not serving"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
client, handler, err := grpcclient.SetupTokenClient(context.Background(), c.config)
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s", err, c.err))
|
||||
if err == nil {
|
||||
assert.NotNil(t, client)
|
||||
assert.NotNil(t, handler)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupThingsClient(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
registerThingsServiceServer := func(srv *grpc.Server) {
|
||||
magistrala.RegisterThingsServiceServer(srv, thingsgrpcapi.NewServer(new(thmocks.Service)))
|
||||
}
|
||||
gs := grpcserver.NewServer(ctx, cancel, "things", server.Config{Port: "12345"}, registerThingsServiceServer, mglog.NewMock())
|
||||
go func() {
|
||||
err := gs.Start()
|
||||
assert.Nil(t, err, fmt.Sprintf(`"Unexpected error creating server %s"`, err))
|
||||
}()
|
||||
defer func() {
|
||||
err := gs.Stop()
|
||||
assert.Nil(t, err, fmt.Sprintf(`"Unexpected error stopping server %s"`, err))
|
||||
}()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
config grpcclient.Config
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "successful",
|
||||
config: grpcclient.Config{
|
||||
URL: "localhost:12345",
|
||||
Timeout: time.Second,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed with empty URL",
|
||||
config: grpcclient.Config{
|
||||
URL: "",
|
||||
Timeout: time.Second,
|
||||
},
|
||||
err: errors.New("service is not serving"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
client, handler, err := grpcclient.SetupThingsClient(context.Background(), c.config)
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s", err, c.err))
|
||||
if err == nil {
|
||||
assert.NotNil(t, client)
|
||||
assert.NotNil(t, handler)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetupDomainsClient(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
registerDomainsServiceServer := func(srv *grpc.Server) {
|
||||
magistrala.RegisterDomainsServiceServer(srv, domainsgrpcapi.NewDomainsServer(new(mocks.Service)))
|
||||
}
|
||||
gs := grpcserver.NewServer(ctx, cancel, "auth", server.Config{Port: "12345"}, registerDomainsServiceServer, mglog.NewMock())
|
||||
go func() {
|
||||
err := gs.Start()
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating server %s", err))
|
||||
}()
|
||||
defer func() {
|
||||
err := gs.Stop()
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error stopping server %s", err))
|
||||
}()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
config grpcclient.Config
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "successfully",
|
||||
config: grpcclient.Config{
|
||||
URL: "localhost:12345",
|
||||
Timeout: time.Second,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed with empty URL",
|
||||
config: grpcclient.Config{
|
||||
URL: "",
|
||||
Timeout: time.Second,
|
||||
},
|
||||
err: errors.New("service is not serving"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
client, handler, err := grpcclient.SetupDomainsClient(context.Background(), c.config)
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s", err, c.err))
|
||||
if err == nil {
|
||||
assert.NotNil(t, client)
|
||||
assert.NotNil(t, handler)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,153 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpcclient
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
type security int
|
||||
|
||||
const (
|
||||
withoutTLS security = iota
|
||||
withTLS
|
||||
withmTLS
|
||||
)
|
||||
const buffSize = 10 * 1024 * 1024
|
||||
|
||||
var (
|
||||
errGrpcConnect = errors.New("failed to connect to grpc server")
|
||||
errGrpcClose = errors.New("failed to close grpc connection")
|
||||
ErrSvcNotServing = errors.New("service is not serving")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
URL string `env:"URL" envDefault:""`
|
||||
Timeout time.Duration `env:"TIMEOUT" envDefault:"1s"`
|
||||
ClientCert string `env:"CLIENT_CERT" envDefault:""`
|
||||
ClientKey string `env:"CLIENT_KEY" envDefault:""`
|
||||
ServerCAFile string `env:"SERVER_CA_CERTS" envDefault:""`
|
||||
}
|
||||
|
||||
// Handler is used to handle gRPC connection.
|
||||
type Handler interface {
|
||||
// Close closes gRPC connection.
|
||||
Close() error
|
||||
|
||||
// Secure is used for pretty printing TLS info.
|
||||
Secure() string
|
||||
|
||||
// Connection returns the gRPC connection.
|
||||
Connection() *grpc.ClientConn
|
||||
}
|
||||
|
||||
type client struct {
|
||||
*grpc.ClientConn
|
||||
cfg Config
|
||||
secure security
|
||||
}
|
||||
|
||||
var _ Handler = (*client)(nil)
|
||||
|
||||
func NewHandler(cfg Config) (Handler, error) {
|
||||
conn, secure, err := connect(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &client{
|
||||
ClientConn: conn,
|
||||
cfg: cfg,
|
||||
secure: secure,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
if err := c.ClientConn.Close(); err != nil {
|
||||
return errors.Wrap(errGrpcClose, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) Connection() *grpc.ClientConn {
|
||||
return c.ClientConn
|
||||
}
|
||||
|
||||
// Secure is used for pretty printing TLS info.
|
||||
func (c *client) Secure() string {
|
||||
switch c.secure {
|
||||
case withTLS:
|
||||
return "with TLS"
|
||||
case withmTLS:
|
||||
return "with mTLS"
|
||||
case withoutTLS:
|
||||
fallthrough
|
||||
default:
|
||||
return "without TLS"
|
||||
}
|
||||
}
|
||||
|
||||
// connect creates new gRPC client and connect to gRPC server.
|
||||
func connect(cfg Config) (*grpc.ClientConn, security, error) {
|
||||
opts := []grpc.DialOption{
|
||||
grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
|
||||
}
|
||||
secure := withoutTLS
|
||||
tc := insecure.NewCredentials()
|
||||
|
||||
if cfg.ServerCAFile != "" {
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
// Loading root ca certificates file
|
||||
rootCA, err := os.ReadFile(cfg.ServerCAFile)
|
||||
if err != nil {
|
||||
return nil, secure, fmt.Errorf("failed to load root ca file: %w", err)
|
||||
}
|
||||
if len(rootCA) > 0 {
|
||||
capool := x509.NewCertPool()
|
||||
if !capool.AppendCertsFromPEM(rootCA) {
|
||||
return nil, secure, fmt.Errorf("failed to append root ca to tls.Config")
|
||||
}
|
||||
tlsConfig.RootCAs = capool
|
||||
secure = withTLS
|
||||
}
|
||||
|
||||
// Loading mtls certificates file
|
||||
if cfg.ClientCert != "" || cfg.ClientKey != "" {
|
||||
certificate, err := tls.LoadX509KeyPair(cfg.ClientCert, cfg.ClientKey)
|
||||
if err != nil {
|
||||
return nil, secure, fmt.Errorf("failed to client certificate and key %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{certificate}
|
||||
secure = withmTLS
|
||||
}
|
||||
|
||||
tc = credentials.NewTLS(tlsConfig)
|
||||
}
|
||||
|
||||
opts = append(
|
||||
opts, grpc.WithTransportCredentials(tc),
|
||||
grpc.WithReadBufferSize(buffSize),
|
||||
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(buffSize/10), grpc.MaxCallSendMsgSize(buffSize/10)),
|
||||
grpc.WithWriteBufferSize(buffSize),
|
||||
)
|
||||
|
||||
conn, err := grpc.NewClient(cfg.URL, opts...)
|
||||
if err != nil {
|
||||
return nil, secure, errors.Wrap(errGrpcConnect, err)
|
||||
}
|
||||
|
||||
return conn, secure, nil
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpcclient
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
config Config
|
||||
err error
|
||||
secure string
|
||||
}{
|
||||
{
|
||||
desc: "successful without TLS",
|
||||
config: Config{
|
||||
URL: "localhost:8080",
|
||||
Timeout: time.Second,
|
||||
},
|
||||
err: nil,
|
||||
secure: "without TLS",
|
||||
},
|
||||
{
|
||||
desc: "successful with TLS",
|
||||
config: Config{
|
||||
URL: "localhost:8080",
|
||||
Timeout: time.Second,
|
||||
ServerCAFile: "../../docker/ssl/certs/ca.crt",
|
||||
},
|
||||
err: nil,
|
||||
secure: "with TLS",
|
||||
},
|
||||
{
|
||||
desc: "successful with mTLS",
|
||||
config: Config{
|
||||
URL: "localhost:8080",
|
||||
Timeout: time.Second,
|
||||
ClientCert: "../../docker/ssl/certs/magistrala-server.crt",
|
||||
ClientKey: "../../docker/ssl/certs/magistrala-server.key",
|
||||
ServerCAFile: "../../docker/ssl/certs/ca.crt",
|
||||
},
|
||||
err: nil,
|
||||
secure: "with mTLS",
|
||||
},
|
||||
{
|
||||
desc: "failed with empty URL",
|
||||
config: Config{
|
||||
URL: "",
|
||||
Timeout: time.Second,
|
||||
},
|
||||
secure: "without TLS",
|
||||
},
|
||||
{
|
||||
desc: "failed with invalid server CA file",
|
||||
config: Config{
|
||||
URL: "localhost:8080",
|
||||
Timeout: time.Second,
|
||||
ServerCAFile: "invalid",
|
||||
},
|
||||
err: errors.New("failed to load root ca file: open invalid: no such file or directory"),
|
||||
},
|
||||
{
|
||||
desc: "failed with invalid server CA file as cert key",
|
||||
config: Config{
|
||||
URL: "localhost:8080",
|
||||
Timeout: time.Second,
|
||||
ServerCAFile: "../../docker/ssl/certs/magistrala-server.key",
|
||||
},
|
||||
err: errors.New("failed to append root ca to tls.Config"),
|
||||
},
|
||||
{
|
||||
desc: "failed with invalid client cert",
|
||||
config: Config{
|
||||
URL: "localhost:8080",
|
||||
Timeout: time.Second,
|
||||
ClientCert: "invalid",
|
||||
ClientKey: "../../docker/ssl/certs/magistrala-server.key",
|
||||
ServerCAFile: "../../docker/ssl/certs/ca.crt",
|
||||
},
|
||||
err: errors.New("failed to client certificate and key open invalid: no such file or directory"),
|
||||
},
|
||||
{
|
||||
desc: "failed with invalid client key",
|
||||
config: Config{
|
||||
URL: "localhost:8080",
|
||||
Timeout: time.Second,
|
||||
ClientCert: "../../docker/ssl/certs/magistrala-server.crt",
|
||||
ClientKey: "invalid",
|
||||
ServerCAFile: "../../docker/ssl/certs/ca.crt",
|
||||
},
|
||||
err: errors.New("failed to client certificate and key open invalid: no such file or directory"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
handler, err := NewHandler(c.config)
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s", err, c.err))
|
||||
if err == nil {
|
||||
assert.Equal(t, c.secure, handler.Secure())
|
||||
assert.NotNil(t, handler.Connection())
|
||||
assert.Nil(t, handler.Close())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package auth contains the domain concept definitions needed to support
|
||||
// Magistrala auth functionality.
|
||||
package grpcclient
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package jaeger contains the domain concept definitions needed to support
|
||||
// Magistrala Jaeger tracing functionality.
|
||||
package jaeger
|
||||
@@ -1,77 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package jaeger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/url"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
"go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
|
||||
)
|
||||
|
||||
var (
|
||||
errNoURL = errors.New("URL is empty")
|
||||
errNoSvcName = errors.New("service Name is empty")
|
||||
errUnsupportedTraceURLScheme = errors.New("unsupported tracing url scheme")
|
||||
)
|
||||
|
||||
// NewProvider initializes Jaeger TraceProvider.
|
||||
//
|
||||
// tp, err := jaeger.NewProvider(ctx, "demo-service", "http://localhost:14268/api/traces", "2cb32911-6833-469c-9cad-4d3e93c528d8", "1.0")
|
||||
func NewProvider(ctx context.Context, svcName string, jaegerUrl url.URL, instanceID string, fraction float64) (*trace.TracerProvider, error) {
|
||||
if jaegerUrl == (url.URL{}) {
|
||||
return nil, errNoURL
|
||||
}
|
||||
|
||||
if svcName == "" {
|
||||
return nil, errNoSvcName
|
||||
}
|
||||
|
||||
var client otlptrace.Client
|
||||
switch jaegerUrl.Scheme {
|
||||
case "http":
|
||||
client = otlptracehttp.NewClient(otlptracehttp.WithEndpoint(jaegerUrl.Host), otlptracehttp.WithURLPath(jaegerUrl.Path), otlptracehttp.WithInsecure())
|
||||
case "https":
|
||||
client = otlptracehttp.NewClient(otlptracehttp.WithEndpoint(jaegerUrl.Host), otlptracehttp.WithURLPath(jaegerUrl.Path))
|
||||
default:
|
||||
return nil, errUnsupportedTraceURLScheme
|
||||
}
|
||||
|
||||
exporter, err := otlptrace.New(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attributes := []attribute.KeyValue{
|
||||
semconv.ServiceNameKey.String(svcName),
|
||||
attribute.String("host.id", instanceID),
|
||||
}
|
||||
|
||||
hostAttr, err := resource.New(ctx, resource.WithHost(), resource.WithOSDescription(), resource.WithContainer())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attributes = append(attributes, hostAttr.Attributes()...)
|
||||
|
||||
tp := trace.NewTracerProvider(
|
||||
trace.WithSampler(trace.TraceIDRatioBased(fraction)),
|
||||
trace.WithBatcher(exporter),
|
||||
trace.WithResource(resource.NewWithAttributes(
|
||||
semconv.SchemaURL,
|
||||
attributes...,
|
||||
)),
|
||||
)
|
||||
otel.SetTracerProvider(tp)
|
||||
otel.SetTextMapPropagator(propagation.TraceContext{})
|
||||
|
||||
return tp, nil
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
# Messaging
|
||||
|
||||
`messaging` package defines `Publisher`, `Subscriber` and an aggregate `Pubsub` interface.
|
||||
|
||||
`Subscriber` interface defines methods used to subscribe to a message broker such as MQTT or NATS or RabbitMQ.
|
||||
|
||||
`Publisher` interface defines methods used to publish messages to a message broker such as MQTT or NATS or RabbitMQ.
|
||||
|
||||
`Pubsub` interface is composed of `Publisher` and `Subscriber` interface and can be used to send messages to as well as to receive messages from a message broker.
|
||||
@@ -1,41 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !rabbitmq
|
||||
// +build !rabbitmq
|
||||
|
||||
package brokers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/nats"
|
||||
)
|
||||
|
||||
// SubjectAllChannels represents subject to subscribe for all the channels.
|
||||
const SubjectAllChannels = "channels.>"
|
||||
|
||||
func init() {
|
||||
log.Println("The binary was build using Nats as the message broker")
|
||||
}
|
||||
|
||||
func NewPublisher(ctx context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) {
|
||||
pb, err := nats.NewPublisher(ctx, url, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
|
||||
func NewPubSub(ctx context.Context, url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) {
|
||||
pb, err := nats.NewPubSub(ctx, url, logger, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build rabbitmq
|
||||
// +build rabbitmq
|
||||
|
||||
package brokers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/rabbitmq"
|
||||
)
|
||||
|
||||
// SubjectAllChannels represents subject to subscribe for all the channels.
|
||||
const SubjectAllChannels = "channels.#"
|
||||
|
||||
func init() {
|
||||
log.Println("The binary was build using RabbitMQ as the message broker")
|
||||
}
|
||||
|
||||
func NewPublisher(_ context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) {
|
||||
pb, err := rabbitmq.NewPublisher(url, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
|
||||
func NewPubSub(_ context.Context, url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) {
|
||||
pb, err := rabbitmq.NewPubSub(url, logger, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pb, nil
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !rabbitmq
|
||||
// +build !rabbitmq
|
||||
|
||||
package brokers
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/nats/tracing"
|
||||
"github.com/absmach/magistrala/pkg/server"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// SubjectAllChannels represents subject to subscribe for all the channels.
|
||||
const SubjectAllChannels = "channels.>"
|
||||
|
||||
func init() {
|
||||
log.Println("The binary was build using Nats as the message broker")
|
||||
}
|
||||
|
||||
func NewPublisher(cfg server.Config, tracer trace.Tracer, publisher messaging.Publisher) messaging.Publisher {
|
||||
return tracing.NewPublisher(cfg, tracer, publisher)
|
||||
}
|
||||
|
||||
func NewPubSub(cfg server.Config, tracer trace.Tracer, pubsub messaging.PubSub) messaging.PubSub {
|
||||
return tracing.NewPubSub(cfg, tracer, pubsub)
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
//go:build rabbitmq
|
||||
// +build rabbitmq
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package brokers
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/rabbitmq/tracing"
|
||||
"github.com/absmach/magistrala/pkg/server"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// SubjectAllChannels represents subject to subscribe for all the channels.
|
||||
const SubjectAllChannels = "channels.#"
|
||||
|
||||
func init() {
|
||||
log.Println("The binary was build using RabbitMQ as the message broker")
|
||||
}
|
||||
|
||||
func NewPublisher(cfg server.Config, tracer trace.Tracer, pub messaging.Publisher) messaging.Publisher {
|
||||
return tracing.NewPublisher(cfg, tracer, pub)
|
||||
}
|
||||
|
||||
func NewPubSub(cfg server.Config, tracer trace.Tracer, pubsub messaging.PubSub) messaging.PubSub {
|
||||
return tracing.NewPubSub(cfg, tracer, pubsub)
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !test
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
)
|
||||
|
||||
var _ session.Handler = (*loggingMiddleware)(nil)
|
||||
|
||||
type loggingMiddleware struct {
|
||||
logger *slog.Logger
|
||||
svc session.Handler
|
||||
}
|
||||
|
||||
// AuthConnect implements session.Handler.
|
||||
func (lm *loggingMiddleware) AuthConnect(ctx context.Context) (err error) {
|
||||
defer lm.logAction("AuthConnect", nil, time.Now(), err)
|
||||
return lm.svc.AuthConnect(ctx)
|
||||
}
|
||||
|
||||
// AuthPublish implements session.Handler.
|
||||
func (lm *loggingMiddleware) AuthPublish(ctx context.Context, topic *string, payload *[]byte) (err error) {
|
||||
defer lm.logAction("AuthPublish", &[]string{*topic}, time.Now(), err)
|
||||
return lm.svc.AuthPublish(ctx, topic, payload)
|
||||
}
|
||||
|
||||
// AuthSubscribe implements session.Handler.
|
||||
func (lm *loggingMiddleware) AuthSubscribe(ctx context.Context, topics *[]string) (err error) {
|
||||
defer lm.logAction("AuthSubscribe", topics, time.Now(), err)
|
||||
return lm.svc.AuthSubscribe(ctx, topics)
|
||||
}
|
||||
|
||||
// Connect implements session.Handler.
|
||||
func (lm *loggingMiddleware) Connect(ctx context.Context) (err error) {
|
||||
defer lm.logAction("Connect", nil, time.Now(), err)
|
||||
return lm.svc.Connect(ctx)
|
||||
}
|
||||
|
||||
// Disconnect implements session.Handler.
|
||||
func (lm *loggingMiddleware) Disconnect(ctx context.Context) (err error) {
|
||||
defer lm.logAction("Disconnect", nil, time.Now(), err)
|
||||
return lm.svc.Disconnect(ctx)
|
||||
}
|
||||
|
||||
// Publish logs the publish request. It logs the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) Publish(ctx context.Context, topic *string, payload *[]byte) (err error) {
|
||||
defer lm.logAction("Publish", &[]string{*topic}, time.Now(), err)
|
||||
return lm.svc.Publish(ctx, topic, payload)
|
||||
}
|
||||
|
||||
// Subscribe implements session.Handler.
|
||||
func (lm *loggingMiddleware) Subscribe(ctx context.Context, topics *[]string) (err error) {
|
||||
defer lm.logAction("Subscribe", topics, time.Now(), err)
|
||||
return lm.svc.Subscribe(ctx, topics)
|
||||
}
|
||||
|
||||
// Unsubscribe implements session.Handler.
|
||||
func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, topics *[]string) (err error) {
|
||||
defer lm.logAction("Unsubscribe", topics, time.Now(), err)
|
||||
return lm.svc.Unsubscribe(ctx, topics)
|
||||
}
|
||||
|
||||
// LoggingMiddleware adds logging facilities to the adapter.
|
||||
func LoggingMiddleware(svc session.Handler, logger *slog.Logger) session.Handler {
|
||||
return &loggingMiddleware{logger, svc}
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) logAction(action string, topics *[]string, t time.Time, err error) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(t).String()),
|
||||
}
|
||||
if topics != nil {
|
||||
args = append(args, slog.Any("topics", *topics))
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn(action+" failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info(action+" completed successfully", args...)
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !test
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
"github.com/go-kit/kit/metrics"
|
||||
)
|
||||
|
||||
var _ session.Handler = (*metricsMiddleware)(nil)
|
||||
|
||||
type metricsMiddleware struct {
|
||||
counter metrics.Counter
|
||||
latency metrics.Histogram
|
||||
svc session.Handler
|
||||
}
|
||||
|
||||
// MetricsMiddleware instruments adapter by tracking request count and latency.
|
||||
func MetricsMiddleware(svc session.Handler, counter metrics.Counter, latency metrics.Histogram) session.Handler {
|
||||
return &metricsMiddleware{
|
||||
counter: counter,
|
||||
latency: latency,
|
||||
svc: svc,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthConnect implements session.Handler.
|
||||
func (mm *metricsMiddleware) AuthConnect(ctx context.Context) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "publish").Add(1)
|
||||
mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.AuthConnect(ctx)
|
||||
}
|
||||
|
||||
// AuthPublish implements session.Handler.
|
||||
func (mm *metricsMiddleware) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "publish").Add(1)
|
||||
mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.AuthPublish(ctx, topic, payload)
|
||||
}
|
||||
|
||||
// AuthSubscribe implements session.Handler.
|
||||
func (*metricsMiddleware) AuthSubscribe(ctx context.Context, topics *[]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect implements session.Handler.
|
||||
func (*metricsMiddleware) Connect(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect implements session.Handler.
|
||||
func (*metricsMiddleware) Disconnect(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish instruments Publish method with metrics.
|
||||
func (mm *metricsMiddleware) Publish(ctx context.Context, topic *string, payload *[]byte) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "publish").Add(1)
|
||||
mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.Publish(ctx, topic, payload)
|
||||
}
|
||||
|
||||
// Subscribe implements session.Handler.
|
||||
func (*metricsMiddleware) Subscribe(ctx context.Context, topics *[]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unsubscribe implements session.Handler.
|
||||
func (*metricsMiddleware) Unsubscribe(ctx context.Context, topics *[]string) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/mgate/pkg/session"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const (
|
||||
authConnectOP = "auth_connect_op"
|
||||
authPublishOP = "auth_publish_op"
|
||||
authSubscribeOP = "auth_subscribe_op"
|
||||
connectOP = "connect_op"
|
||||
disconnectOP = "disconnect_op"
|
||||
subscribeOP = "subscribe_op"
|
||||
unsubscribeOP = "unsubscribe_op"
|
||||
publishOP = "publish_op"
|
||||
)
|
||||
|
||||
var _ session.Handler = (*handlerMiddleware)(nil)
|
||||
|
||||
type handlerMiddleware struct {
|
||||
handler session.Handler
|
||||
tracer trace.Tracer
|
||||
}
|
||||
|
||||
// NewHandler creates a new session.Handler middleware with tracing.
|
||||
func NewTracing(tracer trace.Tracer, handler session.Handler) session.Handler {
|
||||
return &handlerMiddleware{
|
||||
tracer: tracer,
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthConnect traces auth connect operations.
|
||||
func (h *handlerMiddleware) AuthConnect(ctx context.Context) error {
|
||||
kvOpts := []attribute.KeyValue{}
|
||||
s, ok := session.FromContext(ctx)
|
||||
if ok {
|
||||
kvOpts = append(kvOpts, attribute.String("client_id", s.ID))
|
||||
kvOpts = append(kvOpts, attribute.String("username", s.Username))
|
||||
}
|
||||
ctx, span := h.tracer.Start(ctx, authConnectOP, trace.WithAttributes(kvOpts...))
|
||||
defer span.End()
|
||||
return h.handler.AuthConnect(ctx)
|
||||
}
|
||||
|
||||
// AuthPublish traces auth publish operations.
|
||||
func (h *handlerMiddleware) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error {
|
||||
kvOpts := []attribute.KeyValue{}
|
||||
s, ok := session.FromContext(ctx)
|
||||
if ok {
|
||||
kvOpts = append(kvOpts, attribute.String("client_id", s.ID))
|
||||
if topic != nil {
|
||||
kvOpts = append(kvOpts, attribute.String("topic", *topic))
|
||||
}
|
||||
}
|
||||
ctx, span := h.tracer.Start(ctx, authPublishOP, trace.WithAttributes(kvOpts...))
|
||||
defer span.End()
|
||||
return h.handler.AuthPublish(ctx, topic, payload)
|
||||
}
|
||||
|
||||
// AuthSubscribe traces auth subscribe operations.
|
||||
func (h *handlerMiddleware) AuthSubscribe(ctx context.Context, topics *[]string) error {
|
||||
kvOpts := []attribute.KeyValue{}
|
||||
s, ok := session.FromContext(ctx)
|
||||
if ok {
|
||||
kvOpts = append(kvOpts, attribute.String("client_id", s.ID))
|
||||
if topics != nil {
|
||||
kvOpts = append(kvOpts, attribute.StringSlice("topics", *topics))
|
||||
}
|
||||
}
|
||||
ctx, span := h.tracer.Start(ctx, authSubscribeOP, trace.WithAttributes(kvOpts...))
|
||||
defer span.End()
|
||||
return h.handler.AuthSubscribe(ctx, topics)
|
||||
}
|
||||
|
||||
// Connect traces connect operations.
|
||||
func (h *handlerMiddleware) Connect(ctx context.Context) error {
|
||||
ctx, span := h.tracer.Start(ctx, connectOP)
|
||||
defer span.End()
|
||||
return h.handler.Connect(ctx)
|
||||
}
|
||||
|
||||
// Disconnect traces disconnect operations.
|
||||
func (h *handlerMiddleware) Disconnect(ctx context.Context) error {
|
||||
ctx, span := h.tracer.Start(ctx, disconnectOP)
|
||||
defer span.End()
|
||||
return h.handler.Disconnect(ctx)
|
||||
}
|
||||
|
||||
// Publish traces publish operations.
|
||||
func (h *handlerMiddleware) Publish(ctx context.Context, topic *string, payload *[]byte) error {
|
||||
ctx, span := h.tracer.Start(ctx, publishOP)
|
||||
defer span.End()
|
||||
return h.handler.Publish(ctx, topic, payload)
|
||||
}
|
||||
|
||||
// Subscribe traces subscribe operations.
|
||||
func (h *handlerMiddleware) Subscribe(ctx context.Context, topics *[]string) error {
|
||||
ctx, span := h.tracer.Start(ctx, subscribeOP)
|
||||
defer span.End()
|
||||
return h.handler.Subscribe(ctx, topics)
|
||||
}
|
||||
|
||||
// Unsubscribe traces unsubscribe operations.
|
||||
func (h *handlerMiddleware) Unsubscribe(ctx context.Context, topics *[]string) error {
|
||||
ctx, span := h.tracer.Start(ctx, unsubscribeOP)
|
||||
defer span.End()
|
||||
return h.handler.Unsubscribe(ctx, topics)
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.34.2
|
||||
// protoc v5.27.1
|
||||
// source: pkg/messaging/message.proto
|
||||
|
||||
package messaging
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// Message represents a message emitted by the Magistrala adapters layer.
|
||||
type Message struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Channel string `protobuf:"bytes,1,opt,name=channel,proto3" json:"channel,omitempty"`
|
||||
Subtopic string `protobuf:"bytes,2,opt,name=subtopic,proto3" json:"subtopic,omitempty"`
|
||||
Publisher string `protobuf:"bytes,3,opt,name=publisher,proto3" json:"publisher,omitempty"`
|
||||
Protocol string `protobuf:"bytes,4,opt,name=protocol,proto3" json:"protocol,omitempty"`
|
||||
Payload []byte `protobuf:"bytes,5,opt,name=payload,proto3" json:"payload,omitempty"`
|
||||
Created int64 `protobuf:"varint,6,opt,name=created,proto3" json:"created,omitempty"` // Unix timestamp in nanoseconds
|
||||
}
|
||||
|
||||
func (x *Message) Reset() {
|
||||
*x = Message{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_pkg_messaging_message_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Message) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Message) ProtoMessage() {}
|
||||
|
||||
func (x *Message) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_pkg_messaging_message_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Message.ProtoReflect.Descriptor instead.
|
||||
func (*Message) Descriptor() ([]byte, []int) {
|
||||
return file_pkg_messaging_message_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *Message) GetChannel() string {
|
||||
if x != nil {
|
||||
return x.Channel
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Message) GetSubtopic() string {
|
||||
if x != nil {
|
||||
return x.Subtopic
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Message) GetPublisher() string {
|
||||
if x != nil {
|
||||
return x.Publisher
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Message) GetProtocol() string {
|
||||
if x != nil {
|
||||
return x.Protocol
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Message) GetPayload() []byte {
|
||||
if x != nil {
|
||||
return x.Payload
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Message) GetCreated() int64 {
|
||||
if x != nil {
|
||||
return x.Created
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
var File_pkg_messaging_message_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_pkg_messaging_message_proto_rawDesc = []byte{
|
||||
0x0a, 0x1b, 0x70, 0x6b, 0x67, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x69, 0x6e, 0x67, 0x2f,
|
||||
0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x09, 0x6d,
|
||||
0x65, 0x73, 0x73, 0x61, 0x67, 0x69, 0x6e, 0x67, 0x22, 0xad, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73,
|
||||
0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x18,
|
||||
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x12, 0x1a,
|
||||
0x0a, 0x08, 0x73, 0x75, 0x62, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x08, 0x73, 0x75, 0x62, 0x74, 0x6f, 0x70, 0x69, 0x63, 0x12, 0x1c, 0x0a, 0x09, 0x70, 0x75,
|
||||
0x62, 0x6c, 0x69, 0x73, 0x68, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70,
|
||||
0x75, 0x62, 0x6c, 0x69, 0x73, 0x68, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18,
|
||||
0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x18,
|
||||
0x0a, 0x07, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52,
|
||||
0x07, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x42, 0x0d, 0x5a, 0x0b, 0x2e, 0x2f, 0x6d, 0x65,
|
||||
0x73, 0x73, 0x61, 0x67, 0x69, 0x6e, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_pkg_messaging_message_proto_rawDescOnce sync.Once
|
||||
file_pkg_messaging_message_proto_rawDescData = file_pkg_messaging_message_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_pkg_messaging_message_proto_rawDescGZIP() []byte {
|
||||
file_pkg_messaging_message_proto_rawDescOnce.Do(func() {
|
||||
file_pkg_messaging_message_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_messaging_message_proto_rawDescData)
|
||||
})
|
||||
return file_pkg_messaging_message_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_pkg_messaging_message_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_pkg_messaging_message_proto_goTypes = []any{
|
||||
(*Message)(nil), // 0: messaging.Message
|
||||
}
|
||||
var file_pkg_messaging_message_proto_depIdxs = []int32{
|
||||
0, // [0:0] is the sub-list for method output_type
|
||||
0, // [0:0] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_pkg_messaging_message_proto_init() }
|
||||
func file_pkg_messaging_message_proto_init() {
|
||||
if File_pkg_messaging_message_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_pkg_messaging_message_proto_msgTypes[0].Exporter = func(v any, i int) any {
|
||||
switch v := v.(*Message); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_pkg_messaging_message_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_pkg_messaging_message_proto_goTypes,
|
||||
DependencyIndexes: file_pkg_messaging_message_proto_depIdxs,
|
||||
MessageInfos: file_pkg_messaging_message_proto_msgTypes,
|
||||
}.Build()
|
||||
File_pkg_messaging_message_proto = out.File
|
||||
file_pkg_messaging_message_proto_rawDesc = nil
|
||||
file_pkg_messaging_message_proto_goTypes = nil
|
||||
file_pkg_messaging_message_proto_depIdxs = nil
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
syntax = "proto3";
|
||||
package messaging;
|
||||
|
||||
option go_package = "./messaging";
|
||||
|
||||
// Message represents a message emitted by the Magistrala adapters layer.
|
||||
message Message {
|
||||
string channel = 1;
|
||||
string subtopic = 2;
|
||||
string publisher = 3;
|
||||
string protocol = 4;
|
||||
bytes payload = 5;
|
||||
int64 created = 6; // Unix timestamp in nanoseconds
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
messaging "github.com/absmach/magistrala/pkg/messaging"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// PubSub is an autogenerated mock type for the PubSub type
|
||||
type PubSub struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Close provides a mock function with given fields:
|
||||
func (_m *PubSub) Close() error {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Close")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func() error); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Publish provides a mock function with given fields: ctx, topic, msg
|
||||
func (_m *PubSub) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
|
||||
ret := _m.Called(ctx, topic, msg)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Publish")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, *messaging.Message) error); ok {
|
||||
r0 = rf(ctx, topic, msg)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Subscribe provides a mock function with given fields: ctx, cfg
|
||||
func (_m *PubSub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
|
||||
ret := _m.Called(ctx, cfg)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Subscribe")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, messaging.SubscriberConfig) error); ok {
|
||||
r0 = rf(ctx, cfg)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Unsubscribe provides a mock function with given fields: ctx, id, topic
|
||||
func (_m *PubSub) Unsubscribe(ctx context.Context, id string, topic string) error {
|
||||
ret := _m.Called(ctx, id, topic)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Unsubscribe")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
|
||||
r0 = rf(ctx, id, topic)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewPubSub creates a new instance of PubSub. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewPubSub(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *PubSub {
|
||||
mock := &PubSub{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package mqtt hold the implementation of the Publisher and PubSub
|
||||
// interfaces for the MQTT messaging system, the internal messaging
|
||||
// broker of the Magistrala IoT platform. Due to the practical requirements
|
||||
// implementation Publisher is created alongside PubSub. The reason for
|
||||
// this is that Subscriber implementation of MQTT brings the burden of
|
||||
// additional struct fields which are not used by Publisher. Subscriber
|
||||
// is not implemented separately because PubSub can be used where Subscriber is needed.
|
||||
package mqtt
|
||||
@@ -1,61 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
)
|
||||
|
||||
var errPublishTimeout = errors.New("failed to publish due to timeout reached")
|
||||
|
||||
var _ messaging.Publisher = (*publisher)(nil)
|
||||
|
||||
type publisher struct {
|
||||
client mqtt.Client
|
||||
timeout time.Duration
|
||||
qos uint8
|
||||
}
|
||||
|
||||
// NewPublisher returns a new MQTT message publisher.
|
||||
func NewPublisher(address string, qos uint8, timeout time.Duration) (messaging.Publisher, error) {
|
||||
client, err := newClient(address, "mqtt-publisher", timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := publisher{
|
||||
client: client,
|
||||
timeout: timeout,
|
||||
qos: qos,
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (pub publisher) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
|
||||
if topic == "" {
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
|
||||
// Publish only the payload and not the whole message.
|
||||
token := pub.client.Publish(topic, byte(pub.qos), false, msg.GetPayload())
|
||||
if token.Error() != nil {
|
||||
return token.Error()
|
||||
}
|
||||
|
||||
if ok := token.WaitTimeout(pub.timeout); !ok {
|
||||
return errPublishTimeout
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pub publisher) Close() error {
|
||||
pub.client.Disconnect(uint(pub.timeout))
|
||||
return nil
|
||||
}
|
||||
@@ -1,230 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const username = "magistrala-mqtt"
|
||||
|
||||
var (
|
||||
// ErrConnect indicates that connection to MQTT broker failed.
|
||||
ErrConnect = errors.New("failed to connect to MQTT broker")
|
||||
|
||||
// errSubscribeTimeout indicates that the subscription failed due to timeout.
|
||||
errSubscribeTimeout = errors.New("failed to subscribe due to timeout reached")
|
||||
|
||||
// errUnsubscribeTimeout indicates that unsubscribe failed due to timeout.
|
||||
errUnsubscribeTimeout = errors.New("failed to unsubscribe due to timeout reached")
|
||||
|
||||
// errUnsubscribeDeleteTopic indicates that unsubscribe failed because the topic was deleted.
|
||||
errUnsubscribeDeleteTopic = errors.New("failed to unsubscribe due to deletion of topic")
|
||||
|
||||
// ErrNotSubscribed indicates that the topic is not subscribed to.
|
||||
ErrNotSubscribed = errors.New("not subscribed")
|
||||
|
||||
// ErrEmptyTopic indicates the absence of topic.
|
||||
ErrEmptyTopic = errors.New("empty topic")
|
||||
|
||||
// ErrEmptyID indicates the absence of ID.
|
||||
ErrEmptyID = errors.New("empty ID")
|
||||
)
|
||||
|
||||
var _ messaging.PubSub = (*pubsub)(nil)
|
||||
|
||||
type subscription struct {
|
||||
client mqtt.Client
|
||||
topics []string
|
||||
cancel func() error
|
||||
}
|
||||
|
||||
type pubsub struct {
|
||||
publisher
|
||||
logger *slog.Logger
|
||||
mu sync.RWMutex
|
||||
address string
|
||||
timeout time.Duration
|
||||
subscriptions map[string]subscription
|
||||
}
|
||||
|
||||
// NewPubSub returns MQTT message publisher/subscriber.
|
||||
func NewPubSub(url string, qos uint8, timeout time.Duration, logger *slog.Logger) (messaging.PubSub, error) {
|
||||
client, err := newClient(url, "mqtt-publisher", timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := &pubsub{
|
||||
publisher: publisher{
|
||||
client: client,
|
||||
timeout: timeout,
|
||||
qos: qos,
|
||||
},
|
||||
address: url,
|
||||
timeout: timeout,
|
||||
logger: logger,
|
||||
subscriptions: make(map[string]subscription),
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
|
||||
if cfg.ID == "" {
|
||||
return ErrEmptyID
|
||||
}
|
||||
if cfg.Topic == "" {
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
s, ok := ps.subscriptions[cfg.ID]
|
||||
// If the client exists, check if it's subscribed to the topic and unsubscribe if needed.
|
||||
switch ok {
|
||||
case true:
|
||||
if ok := s.contains(cfg.Topic); ok {
|
||||
if err := s.unsubscribe(cfg.Topic, ps.timeout); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
default:
|
||||
client, err := newClient(ps.address, cfg.ID, ps.timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s = subscription{
|
||||
client: client,
|
||||
topics: []string{},
|
||||
cancel: cfg.Handler.Cancel,
|
||||
}
|
||||
}
|
||||
s.topics = append(s.topics, cfg.Topic)
|
||||
ps.subscriptions[cfg.ID] = s
|
||||
|
||||
token := s.client.Subscribe(cfg.Topic, byte(ps.qos), ps.mqttHandler(cfg.Handler))
|
||||
if token.Error() != nil {
|
||||
return token.Error()
|
||||
}
|
||||
if ok := token.WaitTimeout(ps.timeout); !ok {
|
||||
return errSubscribeTimeout
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error {
|
||||
if id == "" {
|
||||
return ErrEmptyID
|
||||
}
|
||||
if topic == "" {
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
s, ok := ps.subscriptions[id]
|
||||
if !ok || !s.contains(topic) {
|
||||
return ErrNotSubscribed
|
||||
}
|
||||
|
||||
if err := s.unsubscribe(topic, ps.timeout); err != nil {
|
||||
return err
|
||||
}
|
||||
ps.subscriptions[id] = s
|
||||
|
||||
if len(s.topics) == 0 {
|
||||
delete(ps.subscriptions, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *subscription) unsubscribe(topic string, timeout time.Duration) error {
|
||||
if s.cancel != nil {
|
||||
if err := s.cancel(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
token := s.client.Unsubscribe(topic)
|
||||
if token.Error() != nil {
|
||||
return token.Error()
|
||||
}
|
||||
|
||||
if ok := token.WaitTimeout(timeout); !ok {
|
||||
return errUnsubscribeTimeout
|
||||
}
|
||||
if ok := s.delete(topic); !ok {
|
||||
return errUnsubscribeDeleteTopic
|
||||
}
|
||||
return token.Error()
|
||||
}
|
||||
|
||||
func newClient(address, id string, timeout time.Duration) (mqtt.Client, error) {
|
||||
opts := mqtt.NewClientOptions().
|
||||
SetUsername(username).
|
||||
AddBroker(address).
|
||||
SetClientID(id)
|
||||
client := mqtt.NewClient(opts)
|
||||
token := client.Connect()
|
||||
if token.Error() != nil {
|
||||
return nil, token.Error()
|
||||
}
|
||||
|
||||
if ok := token.WaitTimeout(timeout); !ok {
|
||||
return nil, ErrConnect
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (ps *pubsub) mqttHandler(h messaging.MessageHandler) mqtt.MessageHandler {
|
||||
return func(_ mqtt.Client, m mqtt.Message) {
|
||||
var msg messaging.Message
|
||||
if err := proto.Unmarshal(m.Payload(), &msg); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.Handle(&msg); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to handle Magistrala message: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Contains checks if a topic is present.
|
||||
func (s subscription) contains(topic string) bool {
|
||||
return s.indexOf(topic) != -1
|
||||
}
|
||||
|
||||
// Finds the index of an item in the topics.
|
||||
func (s subscription) indexOf(element string) int {
|
||||
for k, v := range s.topics {
|
||||
if element == v {
|
||||
return k
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// Deletes a topic from the slice.
|
||||
func (s *subscription) delete(topic string) bool {
|
||||
index := s.indexOf(topic)
|
||||
if index == -1 {
|
||||
return false
|
||||
}
|
||||
topics := make([]string, len(s.topics)-1)
|
||||
copy(topics[:index], s.topics[:index])
|
||||
copy(topics[index:], s.topics[index+1:])
|
||||
s.topics = topics
|
||||
return true
|
||||
}
|
||||
@@ -1,474 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mqtt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
mqttpubsub "github.com/absmach/magistrala/pkg/messaging/mqtt"
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
topic = "topic"
|
||||
chansPrefix = "channels"
|
||||
channel = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
|
||||
subtopic = "engine"
|
||||
tokenTimeout = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
var data = []byte("payload")
|
||||
|
||||
// ErrFailedHandleMessage indicates that the message couldn't be handled.
|
||||
var errFailedHandleMessage = errors.New("failed to handle magistrala message")
|
||||
|
||||
func TestPublisher(t *testing.T) {
|
||||
msgChan := make(chan []byte)
|
||||
|
||||
// Subscribing with topic, and with subtopic, so that we can publish messages.
|
||||
client, err := newClient(address, "clientID1", brokerTimeout)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
token := client.Subscribe(topic, qos, func(_ mqtt.Client, m mqtt.Message) {
|
||||
msgChan <- m.Payload()
|
||||
})
|
||||
if ok := token.WaitTimeout(tokenTimeout); !ok {
|
||||
assert.Fail(t, fmt.Sprintf("failed to subscribe to topic %s", topic))
|
||||
}
|
||||
assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error()))
|
||||
|
||||
token = client.Subscribe(fmt.Sprintf("%s.%s", topic, subtopic), qos, func(_ mqtt.Client, m mqtt.Message) {
|
||||
msgChan <- m.Payload()
|
||||
})
|
||||
if ok := token.WaitTimeout(tokenTimeout); !ok {
|
||||
assert.Fail(t, fmt.Sprintf("failed to subscribe to topic %s", fmt.Sprintf("%s.%s", topic, subtopic)))
|
||||
}
|
||||
assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error()))
|
||||
|
||||
t.Cleanup(func() {
|
||||
token := client.Unsubscribe(topic, fmt.Sprintf("%s.%s", topic, subtopic))
|
||||
token.WaitTimeout(tokenTimeout)
|
||||
assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error()))
|
||||
|
||||
client.Disconnect(100)
|
||||
})
|
||||
|
||||
// Test publish with an empty topic.
|
||||
err = pubsub.Publish(context.TODO(), "", &messaging.Message{Payload: data})
|
||||
assert.Equal(t, err, mqttpubsub.ErrEmptyTopic, fmt.Sprintf("Publish with empty topic: expected: %s, got: %s", mqttpubsub.ErrEmptyTopic, err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
channel string
|
||||
subtopic string
|
||||
payload []byte
|
||||
}{
|
||||
{
|
||||
desc: "publish message with nil payload",
|
||||
payload: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish message with string payload",
|
||||
payload: data,
|
||||
},
|
||||
{
|
||||
desc: "publish message with channel",
|
||||
payload: data,
|
||||
channel: channel,
|
||||
},
|
||||
{
|
||||
desc: "publish message with subtopic",
|
||||
payload: data,
|
||||
subtopic: subtopic,
|
||||
},
|
||||
{
|
||||
desc: "publish message with channel and subtopic",
|
||||
payload: data,
|
||||
channel: channel,
|
||||
subtopic: subtopic,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
expectedMsg := messaging.Message{
|
||||
Publisher: "clientID11",
|
||||
Channel: tc.channel,
|
||||
Subtopic: tc.subtopic,
|
||||
Payload: tc.payload,
|
||||
}
|
||||
|
||||
err := pubsub.Publish(context.TODO(), topic, &expectedMsg)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error: %s\n", tc.desc, err))
|
||||
|
||||
data, err := proto.Marshal(&expectedMsg)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: failed to serialize protobuf error: %s\n", tc.desc, err))
|
||||
|
||||
receivedMsg := <-msgChan
|
||||
if tc.payload != nil {
|
||||
assert.Equal(t, expectedMsg.GetPayload(), receivedMsg, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, data, receivedMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
msgChan := make(chan *messaging.Message)
|
||||
|
||||
// Creating client to Publish messages to subscribed topic.
|
||||
client, err := newClient(address, "magistrala", brokerTimeout)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
t.Cleanup(func() {
|
||||
client.Unsubscribe()
|
||||
client.Disconnect(100)
|
||||
})
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic string
|
||||
clientID string
|
||||
err error
|
||||
handler messaging.MessageHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a topic with an ID",
|
||||
topic: topic,
|
||||
clientID: "clientid1",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid1", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same topic with a different ID",
|
||||
topic: topic,
|
||||
clientID: "clientid2",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid2", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an already subscribed topic with an ID",
|
||||
topic: topic,
|
||||
clientID: "clientid1",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid1", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", topic, subtopic),
|
||||
clientID: "clientid1",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid1", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an already subscribed topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", topic, subtopic),
|
||||
clientID: "clientid1",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid1", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty topic with an ID",
|
||||
topic: "",
|
||||
clientID: "clientid1",
|
||||
err: mqttpubsub.ErrEmptyTopic,
|
||||
handler: handler{false, "clientid1", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with empty id",
|
||||
topic: topic,
|
||||
clientID: "",
|
||||
err: mqttpubsub.ErrEmptyID,
|
||||
handler: handler{false, "", msgChan},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: tc.clientID,
|
||||
Topic: tc.topic,
|
||||
Handler: tc.handler,
|
||||
}
|
||||
err = pubsub.Subscribe(context.TODO(), subCfg)
|
||||
assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, err, tc.err))
|
||||
|
||||
if tc.err == nil {
|
||||
expectedMsg := messaging.Message{
|
||||
Publisher: "clientID1",
|
||||
Channel: channel,
|
||||
Subtopic: subtopic,
|
||||
Payload: data,
|
||||
}
|
||||
data, err := proto.Marshal(&expectedMsg)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: failed to serialize protobuf error: %s\n", tc.desc, err))
|
||||
|
||||
token := client.Publish(tc.topic, qos, false, data)
|
||||
token.WaitTimeout(tokenTimeout)
|
||||
assert.Nil(t, token.Error(), fmt.Sprintf("got unexpected error: %s", token.Error()))
|
||||
|
||||
receivedMsg := <-msgChan
|
||||
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubSub(t *testing.T) {
|
||||
msgChan := make(chan *messaging.Message)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic string
|
||||
clientID string
|
||||
err error
|
||||
handler messaging.MessageHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a topic with an ID",
|
||||
topic: topic,
|
||||
clientID: "clientid7",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid7", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same topic with a different ID",
|
||||
topic: topic,
|
||||
clientID: "clientid8",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid8", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", topic, subtopic),
|
||||
clientID: "clientid7",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid7", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty topic with an ID",
|
||||
topic: "",
|
||||
clientID: "clientid7",
|
||||
err: mqttpubsub.ErrEmptyTopic,
|
||||
handler: handler{false, "clientid7", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with empty id",
|
||||
topic: topic,
|
||||
clientID: "",
|
||||
err: mqttpubsub.ErrEmptyID,
|
||||
handler: handler{false, "", msgChan},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: tc.clientID,
|
||||
Topic: tc.topic,
|
||||
Handler: tc.handler,
|
||||
}
|
||||
err := pubsub.Subscribe(context.TODO(), subCfg)
|
||||
assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, err, tc.err))
|
||||
|
||||
if tc.err == nil {
|
||||
// Use pubsub to subscribe to a topic, and then publish messages to that topic.
|
||||
expectedMsg := messaging.Message{
|
||||
Publisher: "clientID",
|
||||
Channel: channel,
|
||||
Subtopic: subtopic,
|
||||
Payload: data,
|
||||
}
|
||||
data, err := proto.Marshal(&expectedMsg)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: failed to serialize protobuf error: %s\n", tc.desc, err))
|
||||
|
||||
msg := messaging.Message{
|
||||
Payload: data,
|
||||
}
|
||||
// Publish message, and then receive it on message channel.
|
||||
err = pubsub.Publish(context.TODO(), topic, &msg)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error: %s\n", tc.desc, err))
|
||||
|
||||
receivedMsg := <-msgChan
|
||||
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
msgChan := make(chan *messaging.Message)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic string
|
||||
clientID string
|
||||
err error
|
||||
subscribe bool // True for subscribe and false for unsubscribe.
|
||||
handler messaging.MessageHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid4",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{false, "clientid4", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same topic with a different ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid9",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{false, "clientid9", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid4",
|
||||
err: nil,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from same topic with different ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid9",
|
||||
err: nil,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid9", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a non-existent topic with an ID",
|
||||
topic: "h",
|
||||
clientID: "clientid4",
|
||||
err: mqttpubsub.ErrNotSubscribed,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from an already unsubscribed topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid4",
|
||||
err: mqttpubsub.ErrNotSubscribed,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientidd4",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{false, "clientidd4", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientidd4",
|
||||
err: nil,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientidd4", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientid4",
|
||||
err: mqttpubsub.ErrNotSubscribed,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from an empty topic with an ID",
|
||||
topic: "",
|
||||
clientID: "clientid4",
|
||||
err: mqttpubsub.ErrEmptyTopic,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with empty ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "",
|
||||
err: mqttpubsub.ErrEmptyID,
|
||||
subscribe: false,
|
||||
handler: handler{false, "", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a new topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"),
|
||||
clientID: "clientid55",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{true, "clientid5", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with an ID with failing handler",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"),
|
||||
clientID: "clientid55",
|
||||
err: errFailedHandleMessage,
|
||||
subscribe: false,
|
||||
handler: handler{true, "clientid5", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a new topic with subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic),
|
||||
clientID: "clientid55",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{true, "clientid5", msgChan},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with subtopic with an ID with failing handler",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic),
|
||||
clientID: "clientid55",
|
||||
err: errFailedHandleMessage,
|
||||
subscribe: false,
|
||||
handler: handler{true, "clientid5", msgChan},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: tc.clientID,
|
||||
Topic: tc.topic,
|
||||
Handler: tc.handler,
|
||||
}
|
||||
switch tc.subscribe {
|
||||
case true:
|
||||
err := pubsub.Subscribe(context.TODO(), subCfg)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
|
||||
default:
|
||||
err := pubsub.Unsubscribe(context.TODO(), tc.clientID, tc.topic)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
fail bool
|
||||
publisher string
|
||||
msgChan chan *messaging.Message
|
||||
}
|
||||
|
||||
func (h handler) Handle(msg *messaging.Message) error {
|
||||
if msg.GetPublisher() != h.publisher {
|
||||
h.msgChan <- msg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h handler) Cancel() error {
|
||||
if h.fail {
|
||||
return errFailedHandleMessage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mqtt_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
mqttpubsub "github.com/absmach/magistrala/pkg/messaging/mqtt"
|
||||
mqtt "github.com/eclipse/paho.mqtt.golang"
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
)
|
||||
|
||||
var (
|
||||
pubsub messaging.PubSub
|
||||
logger *slog.Logger
|
||||
address string
|
||||
)
|
||||
|
||||
const (
|
||||
username = "magistrala-mqtt"
|
||||
qos = 2
|
||||
port = "1883/tcp"
|
||||
brokerTimeout = 30 * time.Second
|
||||
poolMaxWait = 120 * time.Second
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
container, err := pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "eclipse-mosquitto",
|
||||
Tag: "1.6.15",
|
||||
}, func(config *docker.HostConfig) {
|
||||
config.AutoRemove = true
|
||||
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start container: %s", err)
|
||||
}
|
||||
|
||||
handleInterrupt(pool, container)
|
||||
|
||||
address = fmt.Sprintf("%s:%s", "localhost", container.GetPort(port))
|
||||
pool.MaxWait = poolMaxWait
|
||||
|
||||
logger, err = mglog.New(os.Stdout, "debug")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if err := pool.Retry(func() error {
|
||||
pubsub, err = mqttpubsub.NewPubSub(address, 2, brokerTimeout, logger)
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
|
||||
defer func() {
|
||||
err = pubsub.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
|
||||
c := make(chan os.Signal, 2)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-c
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
|
||||
func newClient(address, id string, timeout time.Duration) (mqtt.Client, error) {
|
||||
opts := mqtt.NewClientOptions().
|
||||
SetUsername(username).
|
||||
AddBroker(address).
|
||||
SetClientID(id)
|
||||
|
||||
client := mqtt.NewClient(opts)
|
||||
token := client.Connect()
|
||||
if token.Error() != nil {
|
||||
return nil, token.Error()
|
||||
}
|
||||
|
||||
ok := token.WaitTimeout(timeout)
|
||||
if !ok {
|
||||
return nil, mqttpubsub.ErrConnect
|
||||
}
|
||||
|
||||
if token.Error() != nil {
|
||||
return nil, token.Error()
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package nats hold the implementation of the Publisher and PubSub
|
||||
// interfaces for the NATS messaging system, the internal messaging
|
||||
// broker of the Magistrala IoT platform. Due to the practical requirements
|
||||
// implementation Publisher is created alongside PubSub. The reason for
|
||||
// this is that Subscriber implementation of NATS brings the burden of
|
||||
// additional struct fields which are not used by Publisher. Subscriber
|
||||
// is not implemented separately because PubSub can be used where Subscriber is needed.
|
||||
package nats
|
||||
@@ -1,56 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nats
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
)
|
||||
|
||||
// ErrInvalidType is returned when the provided value is not of the expected type.
|
||||
var ErrInvalidType = errors.New("invalid type")
|
||||
|
||||
// Prefix sets the prefix for the publisher.
|
||||
func Prefix(prefix string) messaging.Option {
|
||||
return func(val interface{}) error {
|
||||
p, ok := val.(*publisher)
|
||||
if !ok {
|
||||
return ErrInvalidType
|
||||
}
|
||||
|
||||
p.prefix = prefix
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// JSStream sets the JetStream for the publisher.
|
||||
func JSStream(stream jetstream.JetStream) messaging.Option {
|
||||
return func(val interface{}) error {
|
||||
p, ok := val.(*publisher)
|
||||
if !ok {
|
||||
return ErrInvalidType
|
||||
}
|
||||
|
||||
p.js = stream
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Stream sets the Stream for the subscriber.
|
||||
func Stream(stream jetstream.Stream) messaging.Option {
|
||||
return func(val interface{}) error {
|
||||
p, ok := val.(*pubsub)
|
||||
if !ok {
|
||||
return ErrInvalidType
|
||||
}
|
||||
|
||||
p.stream = stream
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
broker "github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
// A maximum number of reconnect attempts before NATS connection closes permanently.
|
||||
// Value -1 represents an unlimited number of reconnect retries, i.e. the client
|
||||
// will never give up on retrying to re-establish connection to NATS server.
|
||||
maxReconnects = -1
|
||||
|
||||
// reconnectBufSize is obtained from the maximum number of unpublished events
|
||||
// multiplied by the approximate maximum size of a single event.
|
||||
reconnectBufSize = events.MaxUnpublishedEvents * (1024 * 1024)
|
||||
)
|
||||
|
||||
var _ messaging.Publisher = (*publisher)(nil)
|
||||
|
||||
type publisher struct {
|
||||
js jetstream.JetStream
|
||||
conn *broker.Conn
|
||||
prefix string
|
||||
}
|
||||
|
||||
// NewPublisher returns NATS message Publisher.
|
||||
func NewPublisher(ctx context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) {
|
||||
conn, err := broker.Connect(url, broker.MaxReconnects(maxReconnects), broker.ReconnectBufSize(int(reconnectBufSize)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
js, err := jetstream.New(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := js.CreateStream(ctx, jsStreamConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := &publisher{
|
||||
js: js,
|
||||
conn: conn,
|
||||
prefix: chansPrefix,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if err := opt(ret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (pub *publisher) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
|
||||
if topic == "" {
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("%s.%s", pub.prefix, topic)
|
||||
if msg.GetSubtopic() != "" {
|
||||
subject = fmt.Sprintf("%s.%s", subject, msg.GetSubtopic())
|
||||
}
|
||||
|
||||
_, err = pub.js.Publish(ctx, subject, data)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (pub *publisher) Close() error {
|
||||
pub.conn.Close()
|
||||
return nil
|
||||
}
|
||||
@@ -1,174 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
broker "github.com/nats-io/nats.go"
|
||||
"github.com/nats-io/nats.go/jetstream"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const chansPrefix = "channels"
|
||||
|
||||
// Publisher and Subscriber errors.
|
||||
var (
|
||||
ErrNotSubscribed = errors.New("not subscribed")
|
||||
ErrEmptyTopic = errors.New("empty topic")
|
||||
ErrEmptyID = errors.New("empty id")
|
||||
|
||||
jsStreamConfig = jetstream.StreamConfig{
|
||||
Name: "channels",
|
||||
Description: "Magistrala stream for sending and receiving messages in between Magistrala channels",
|
||||
Subjects: []string{"channels.>"},
|
||||
Retention: jetstream.LimitsPolicy,
|
||||
MaxMsgsPerSubject: 1e6,
|
||||
MaxAge: time.Hour * 24,
|
||||
MaxMsgSize: 1024 * 1024,
|
||||
Discard: jetstream.DiscardOld,
|
||||
Storage: jetstream.FileStorage,
|
||||
}
|
||||
)
|
||||
|
||||
var _ messaging.PubSub = (*pubsub)(nil)
|
||||
|
||||
type pubsub struct {
|
||||
publisher
|
||||
logger *slog.Logger
|
||||
stream jetstream.Stream
|
||||
}
|
||||
|
||||
// NewPubSub returns NATS message publisher/subscriber.
|
||||
// Parameter queue specifies the queue for the Subscribe method.
|
||||
// If queue is specified (is not an empty string), Subscribe method
|
||||
// will execute NATS QueueSubscribe which is conceptually different
|
||||
// from ordinary subscribe. For more information, please take a look
|
||||
// here: https://docs.nats.io/developing-with-nats/receiving/queues.
|
||||
// If the queue is empty, Subscribe will be used.
|
||||
func NewPubSub(ctx context.Context, url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) {
|
||||
conn, err := broker.Connect(url, broker.MaxReconnects(maxReconnects))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
js, err := jetstream.New(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream, err := js.CreateStream(ctx, jsStreamConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := &pubsub{
|
||||
publisher: publisher{
|
||||
js: js,
|
||||
conn: conn,
|
||||
prefix: chansPrefix,
|
||||
},
|
||||
stream: stream,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if err := opt(ret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
|
||||
if cfg.ID == "" {
|
||||
return ErrEmptyID
|
||||
}
|
||||
if cfg.Topic == "" {
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
|
||||
nh := ps.natsHandler(cfg.Handler)
|
||||
|
||||
consumerConfig := jetstream.ConsumerConfig{
|
||||
Name: formatConsumerName(cfg.Topic, cfg.ID),
|
||||
Durable: formatConsumerName(cfg.Topic, cfg.ID),
|
||||
Description: fmt.Sprintf("Magistrala consumer of id %s for cfg.Topic %s", cfg.ID, cfg.Topic),
|
||||
DeliverPolicy: jetstream.DeliverNewPolicy,
|
||||
FilterSubject: cfg.Topic,
|
||||
}
|
||||
|
||||
switch cfg.DeliveryPolicy {
|
||||
case messaging.DeliverNewPolicy:
|
||||
consumerConfig.DeliverPolicy = jetstream.DeliverNewPolicy
|
||||
case messaging.DeliverAllPolicy:
|
||||
consumerConfig.DeliverPolicy = jetstream.DeliverAllPolicy
|
||||
}
|
||||
|
||||
consumer, err := ps.stream.CreateOrUpdateConsumer(ctx, consumerConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create consumer: %w", err)
|
||||
}
|
||||
|
||||
if _, err = consumer.Consume(nh); err != nil {
|
||||
return fmt.Errorf("failed to consume: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error {
|
||||
if id == "" {
|
||||
return ErrEmptyID
|
||||
}
|
||||
if topic == "" {
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
|
||||
err := ps.stream.DeleteConsumer(ctx, formatConsumerName(topic, id))
|
||||
switch {
|
||||
case errors.Is(err, jetstream.ErrConsumerNotFound):
|
||||
return ErrNotSubscribed
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *pubsub) natsHandler(h messaging.MessageHandler) func(m jetstream.Msg) {
|
||||
return func(m jetstream.Msg) {
|
||||
var msg messaging.Message
|
||||
if err := proto.Unmarshal(m.Data(), &msg); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.Handle(&msg); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to handle Magistrala message: %s", err))
|
||||
}
|
||||
if err := m.Ack(); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to ack message: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func formatConsumerName(topic, id string) string {
|
||||
// A durable name cannot contain whitespace, ., *, >, path separators (forward or backwards slash), and non-printable characters.
|
||||
chars := []string{
|
||||
" ", "_",
|
||||
".", "_",
|
||||
"*", "_",
|
||||
">", "_",
|
||||
"/", "_",
|
||||
"\\", "_",
|
||||
}
|
||||
topic = strings.NewReplacer(chars...).Replace(topic)
|
||||
|
||||
return fmt.Sprintf("%s-%s", topic, id)
|
||||
}
|
||||
@@ -1,297 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nats_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/nats"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
topic = "topic"
|
||||
chansPrefix = "channels"
|
||||
channel = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
|
||||
subtopic = "engine"
|
||||
clientID = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
|
||||
)
|
||||
|
||||
var (
|
||||
msgChan = make(chan *messaging.Message)
|
||||
message = &messaging.Message{
|
||||
Channel: channel,
|
||||
Subtopic: subtopic,
|
||||
Publisher: "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b",
|
||||
Protocol: "mqtt",
|
||||
Payload: []byte("payload"),
|
||||
Created: time.Now().UnixNano(),
|
||||
}
|
||||
)
|
||||
|
||||
func TestPublisher(t *testing.T) {
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: clientID,
|
||||
Topic: fmt.Sprintf("%s.>", chansPrefix),
|
||||
Handler: handler{},
|
||||
}
|
||||
err := pubsub.Subscribe(context.TODO(), subCfg)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic string
|
||||
subtopic string
|
||||
message *messaging.Message
|
||||
error error
|
||||
}{
|
||||
{
|
||||
desc: "publish message with empty message",
|
||||
topic: channel,
|
||||
subtopic: subtopic,
|
||||
message: &messaging.Message{},
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish message with message",
|
||||
topic: channel,
|
||||
subtopic: subtopic,
|
||||
message: message,
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish message with topic and empty subtopic",
|
||||
topic: channel,
|
||||
subtopic: "",
|
||||
message: message,
|
||||
error: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish message with subtopic and empty topic",
|
||||
topic: "",
|
||||
subtopic: subtopic,
|
||||
message: message,
|
||||
error: nats.ErrEmptyTopic,
|
||||
},
|
||||
{
|
||||
desc: "publish message with topic and subtopic",
|
||||
topic: channel,
|
||||
subtopic: subtopic,
|
||||
message: message,
|
||||
error: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
tc.message.Subtopic = tc.subtopic
|
||||
err := pubsub.Publish(context.TODO(), tc.topic, tc.message)
|
||||
assert.Equal(t, tc.error, err, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, tc.error, err))
|
||||
|
||||
if err == nil {
|
||||
receivedMsg := <-msgChan
|
||||
assert.Equal(t, tc.message.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, tc.message.Payload, receivedMsg))
|
||||
assert.Equal(t, tc.message.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
|
||||
assert.Equal(t, tc.message.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
|
||||
assert.Equal(t, tc.message.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
|
||||
assert.Equal(t, tc.message.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
|
||||
assert.Equal(t, tc.message.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
|
||||
assert.Equal(t, tc.message.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &tc.message, receivedMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubsub(t *testing.T) {
|
||||
// Test Subscribe and Unsubscribe.
|
||||
subcases := []struct {
|
||||
desc string
|
||||
topic string
|
||||
clientID string
|
||||
errorMessage error
|
||||
pubsub bool // true for subscribe and false for unsubscribe.
|
||||
handler messaging.MessageHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid1",
|
||||
errorMessage: nil,
|
||||
pubsub: true,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe using malformed topic and ID",
|
||||
topic: fmt.Sprintf("%s.>", chansPrefix),
|
||||
clientID: "clientid1",
|
||||
errorMessage: nil,
|
||||
pubsub: true,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe using malformed topic and ID",
|
||||
topic: fmt.Sprintf("%s.*", chansPrefix),
|
||||
clientID: "clientid1",
|
||||
errorMessage: nil,
|
||||
pubsub: true,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same topic with a different ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid2",
|
||||
errorMessage: nil,
|
||||
pubsub: true,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an already subscribed topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid1",
|
||||
errorMessage: nil,
|
||||
pubsub: true,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid1",
|
||||
errorMessage: nil,
|
||||
pubsub: false,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a non-existent topic with an ID",
|
||||
topic: "h",
|
||||
clientID: "clientid1",
|
||||
errorMessage: nats.ErrNotSubscribed,
|
||||
pubsub: false,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from the same topic with a different ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientidd2",
|
||||
errorMessage: nats.ErrNotSubscribed,
|
||||
pubsub: false,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from the same topic with a different ID not subscribed",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientidd3",
|
||||
errorMessage: nats.ErrNotSubscribed,
|
||||
pubsub: false,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from an already unsubscribed topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid1",
|
||||
errorMessage: nats.ErrNotSubscribed,
|
||||
pubsub: false,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientidd1",
|
||||
errorMessage: nil,
|
||||
pubsub: true,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an already subscribed topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientidd1",
|
||||
errorMessage: nil,
|
||||
pubsub: true,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientidd1",
|
||||
errorMessage: nil,
|
||||
pubsub: false,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientid1",
|
||||
errorMessage: nats.ErrNotSubscribed,
|
||||
pubsub: false,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty topic with an ID",
|
||||
topic: "",
|
||||
clientID: "clientid1",
|
||||
errorMessage: nats.ErrEmptyTopic,
|
||||
pubsub: true,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from an empty topic with an ID",
|
||||
topic: "",
|
||||
clientID: "clientid1",
|
||||
errorMessage: nats.ErrEmptyTopic,
|
||||
pubsub: false,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with empty id",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "",
|
||||
errorMessage: nats.ErrEmptyID,
|
||||
pubsub: true,
|
||||
handler: handler{},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with empty id",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "",
|
||||
errorMessage: nats.ErrEmptyID,
|
||||
pubsub: false,
|
||||
handler: handler{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, pc := range subcases {
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: pc.clientID,
|
||||
Topic: pc.topic,
|
||||
Handler: pc.handler,
|
||||
}
|
||||
if pc.pubsub == true {
|
||||
err := pubsub.Subscribe(context.TODO(), subCfg)
|
||||
if pc.errorMessage == nil {
|
||||
assert.Nil(t, err, fmt.Sprintf("%s expected %+v got %+v\n", pc.desc, pc.errorMessage, err))
|
||||
} else {
|
||||
assert.Equal(t, err, pc.errorMessage, fmt.Sprintf("%s expected %+v got %+v\n", pc.desc, pc.errorMessage, err))
|
||||
}
|
||||
} else {
|
||||
err := pubsub.Unsubscribe(context.TODO(), pc.clientID, pc.topic)
|
||||
if pc.errorMessage == nil {
|
||||
assert.Nil(t, err, fmt.Sprintf("%s expected %+v got %+v\n", pc.desc, pc.errorMessage, err))
|
||||
} else {
|
||||
assert.Equal(t, err, pc.errorMessage, fmt.Sprintf("%s expected %+v got %+v\n", pc.desc, pc.errorMessage, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type handler struct{}
|
||||
|
||||
func (h handler) Handle(msg *messaging.Message) error {
|
||||
msgChan <- msg
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h handler) Cancel() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package nats_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/nats"
|
||||
"github.com/ory/dockertest/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
publisher messaging.Publisher
|
||||
pubsub messaging.PubSub
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
container, err := pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "nats",
|
||||
Tag: "2.10.9-alpine",
|
||||
Cmd: []string{"-DVV", "-js"},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start container: %s", err)
|
||||
}
|
||||
handleInterrupt(pool, container)
|
||||
|
||||
address := fmt.Sprintf("nats://%s:%s", "localhost", container.GetPort("4222/tcp"))
|
||||
if err := pool.Retry(func() error {
|
||||
publisher, err = nats.NewPublisher(context.Background(), address)
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
logger, err := mglog.New(os.Stdout, "error")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
if err := pool.Retry(func() error {
|
||||
pubsub, err = nats.NewPubSub(context.Background(), address, logger)
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
|
||||
c := make(chan os.Signal, 2)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-c
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package tracing provides tracing instrumentation for Magistrala things policies service.
|
||||
//
|
||||
// This package provides tracing middleware for Magistrala things policies service.
|
||||
// It can be used to trace incoming requests and add tracing capabilities to
|
||||
// Magistrala things policies service.
|
||||
//
|
||||
// For more details about tracing instrumentation for Magistrala messaging refer
|
||||
// to the documentation at https://docs.magistrala.abstractmachines.fr/tracing/.
|
||||
package tracing
|
||||
@@ -1,52 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/tracing"
|
||||
"github.com/absmach/magistrala/pkg/server"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// Traced operations.
|
||||
const publishOP = "publish"
|
||||
|
||||
var defaultAttributes = []attribute.KeyValue{
|
||||
attribute.String("messaging.system", "nats"),
|
||||
attribute.String("network.protocol.name", "nats"),
|
||||
attribute.String("network.protocol.version", "2.2.4"),
|
||||
}
|
||||
|
||||
var _ messaging.Publisher = (*publisherMiddleware)(nil)
|
||||
|
||||
type publisherMiddleware struct {
|
||||
publisher messaging.Publisher
|
||||
tracer trace.Tracer
|
||||
host server.Config
|
||||
}
|
||||
|
||||
func NewPublisher(config server.Config, tracer trace.Tracer, publisher messaging.Publisher) messaging.Publisher {
|
||||
pub := &publisherMiddleware{
|
||||
publisher: publisher,
|
||||
tracer: tracer,
|
||||
host: config,
|
||||
}
|
||||
|
||||
return pub
|
||||
}
|
||||
|
||||
func (pm *publisherMiddleware) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
|
||||
ctx, span := tracing.CreateSpan(ctx, publishOP, msg.GetPublisher(), topic, msg.GetSubtopic(), len(msg.GetPayload()), pm.host, trace.SpanKindClient, pm.tracer)
|
||||
defer span.End()
|
||||
span.SetAttributes(defaultAttributes...)
|
||||
|
||||
return pm.publisher.Publish(ctx, topic, msg)
|
||||
}
|
||||
|
||||
func (pm *publisherMiddleware) Close() error {
|
||||
return pm.publisher.Close()
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/tracing"
|
||||
"github.com/absmach/magistrala/pkg/server"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// Constants to define different operations to be traced.
|
||||
const (
|
||||
subscribeOP = "receive"
|
||||
unsubscribeOp = "unsubscribe" // This is not specified in the open telemetry spec.
|
||||
processOp = "process"
|
||||
)
|
||||
|
||||
var _ messaging.PubSub = (*pubsubMiddleware)(nil)
|
||||
|
||||
type pubsubMiddleware struct {
|
||||
publisherMiddleware
|
||||
pubsub messaging.PubSub
|
||||
host server.Config
|
||||
}
|
||||
|
||||
// NewPubSub creates a new pubsub middleware that traces pubsub operations.
|
||||
func NewPubSub(config server.Config, tracer trace.Tracer, pubsub messaging.PubSub) messaging.PubSub {
|
||||
pb := &pubsubMiddleware{
|
||||
publisherMiddleware: publisherMiddleware{
|
||||
publisher: pubsub,
|
||||
tracer: tracer,
|
||||
host: config,
|
||||
},
|
||||
pubsub: pubsub,
|
||||
host: config,
|
||||
}
|
||||
|
||||
return pb
|
||||
}
|
||||
|
||||
// Subscribe creates a new subscription and traces the operation.
|
||||
func (pm *pubsubMiddleware) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
|
||||
ctx, span := tracing.CreateSpan(ctx, subscribeOP, cfg.ID, cfg.Topic, "", 0, pm.host, trace.SpanKindClient, pm.tracer)
|
||||
defer span.End()
|
||||
|
||||
span.SetAttributes(defaultAttributes...)
|
||||
|
||||
cfg.Handler = &traceHandler{
|
||||
ctx: ctx,
|
||||
handler: cfg.Handler,
|
||||
tracer: pm.tracer,
|
||||
host: pm.host,
|
||||
topic: cfg.Topic,
|
||||
clientID: cfg.ID,
|
||||
}
|
||||
|
||||
return pm.pubsub.Subscribe(ctx, cfg)
|
||||
}
|
||||
|
||||
// Unsubscribe removes an existing subscription and traces the operation.
|
||||
func (pm *pubsubMiddleware) Unsubscribe(ctx context.Context, id, topic string) error {
|
||||
ctx, span := tracing.CreateSpan(ctx, unsubscribeOp, id, topic, "", 0, pm.host, trace.SpanKindInternal, pm.tracer)
|
||||
defer span.End()
|
||||
|
||||
span.SetAttributes(defaultAttributes...)
|
||||
|
||||
return pm.pubsub.Unsubscribe(ctx, id, topic)
|
||||
}
|
||||
|
||||
// TraceHandler is used to trace the message handling operation.
|
||||
type traceHandler struct {
|
||||
ctx context.Context
|
||||
handler messaging.MessageHandler
|
||||
tracer trace.Tracer
|
||||
host server.Config
|
||||
topic string
|
||||
clientID string
|
||||
}
|
||||
|
||||
// Handle instruments the message handling operation.
|
||||
func (h *traceHandler) Handle(msg *messaging.Message) error {
|
||||
_, span := tracing.CreateSpan(h.ctx, processOp, h.clientID, h.topic, msg.GetSubtopic(), len(msg.GetPayload()), h.host, trace.SpanKindConsumer, h.tracer)
|
||||
defer span.End()
|
||||
|
||||
span.SetAttributes(defaultAttributes...)
|
||||
|
||||
return h.handler.Handle(msg)
|
||||
}
|
||||
|
||||
// Cancel cancels the message handling operation.
|
||||
func (h *traceHandler) Cancel() error {
|
||||
return h.handler.Cancel()
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package messaging
|
||||
|
||||
import "context"
|
||||
|
||||
type DeliveryPolicy uint8
|
||||
|
||||
const (
|
||||
// DeliverNewPolicy will only deliver new messages that are sent after the consumer is created.
|
||||
// This is the default policy.
|
||||
DeliverNewPolicy DeliveryPolicy = iota
|
||||
|
||||
// DeliverAllPolicy starts delivering messages from the very beginning of a stream.
|
||||
DeliverAllPolicy
|
||||
)
|
||||
|
||||
// Publisher specifies message publishing API.
|
||||
type Publisher interface {
|
||||
// Publishes message to the stream.
|
||||
Publish(ctx context.Context, topic string, msg *Message) error
|
||||
|
||||
// Close gracefully closes message publisher's connection.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// MessageHandler represents Message handler for Subscriber.
|
||||
type MessageHandler interface {
|
||||
// Handle handles messages passed by underlying implementation.
|
||||
Handle(msg *Message) error
|
||||
|
||||
// Cancel is used for cleanup during unsubscribing and it's optional.
|
||||
Cancel() error
|
||||
}
|
||||
|
||||
type SubscriberConfig struct {
|
||||
ID string
|
||||
Topic string
|
||||
Handler MessageHandler
|
||||
DeliveryPolicy DeliveryPolicy
|
||||
}
|
||||
|
||||
// Subscriber specifies message subscription API.
|
||||
type Subscriber interface {
|
||||
// Subscribe subscribes to the message stream and consumes messages.
|
||||
Subscribe(ctx context.Context, cfg SubscriberConfig) error
|
||||
|
||||
// Unsubscribe unsubscribes from the message stream and
|
||||
// stops consuming messages.
|
||||
Unsubscribe(ctx context.Context, id, topic string) error
|
||||
|
||||
// Close gracefully closes message subscriber's connection.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// PubSub represents aggregation interface for publisher and subscriber.
|
||||
//
|
||||
//go:generate mockery --name PubSub --filename pubsub.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type PubSub interface {
|
||||
Publisher
|
||||
Subscriber
|
||||
}
|
||||
|
||||
// Option represents optional configuration for message broker.
|
||||
//
|
||||
// This is used to provide optional configuration parameters to the
|
||||
// underlying publisher and pubsub implementation so that it can be
|
||||
// configured to meet the specific needs.
|
||||
//
|
||||
// For example, it can be used to set the message prefix so that
|
||||
// brokers can be used for event sourcing as well as internal message broker.
|
||||
// Using value of type interface is not recommended but is the most suitable
|
||||
// for this use case as options should be compiled with respect to the
|
||||
// underlying broker which can either be RabbitMQ or NATS.
|
||||
//
|
||||
// The example below shows how to set the prefix and jetstream stream for NATS.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// broker.NewPublisher(ctx, url, broker.Prefix(eventsPrefix), broker.JSStream(js))
|
||||
type Option func(vals interface{}) error
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package rabbitmq holds the implementation of the Publisher and PubSub
|
||||
// interfaces for the RabbitMQ messaging system, the internal messaging
|
||||
// broker of the Magistrala IoT platform. Due to the practical requirements
|
||||
// implementation Publisher is created alongside PubSub. The reason for
|
||||
// this is that Subscriber implementation of RabbitMQ brings the burden of
|
||||
// additional struct fields which are not used by Publisher. Subscriber
|
||||
// is not implemented separately because PubSub can be used where Subscriber is needed.
|
||||
package rabbitmq
|
||||
@@ -1,60 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package rabbitmq
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
)
|
||||
|
||||
// ErrInvalidType is returned when the provided value is not of the expected type.
|
||||
var ErrInvalidType = errors.New("invalid type")
|
||||
|
||||
// Prefix sets the prefix for the publisher.
|
||||
func Prefix(prefix string) messaging.Option {
|
||||
return func(val interface{}) error {
|
||||
p, ok := val.(*publisher)
|
||||
if !ok {
|
||||
return ErrInvalidType
|
||||
}
|
||||
|
||||
p.prefix = prefix
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Channel sets the channel for the publisher or subscriber.
|
||||
func Channel(channel *amqp.Channel) messaging.Option {
|
||||
return func(val interface{}) error {
|
||||
switch v := val.(type) {
|
||||
case *publisher:
|
||||
v.channel = channel
|
||||
case *pubsub:
|
||||
v.channel = channel
|
||||
default:
|
||||
return ErrInvalidType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Exchange sets the exchange for the publisher or subscriber.
|
||||
func Exchange(exchange string) messaging.Option {
|
||||
return func(val interface{}) error {
|
||||
switch v := val.(type) {
|
||||
case *publisher:
|
||||
v.exchange = exchange
|
||||
case *pubsub:
|
||||
v.exchange = exchange
|
||||
default:
|
||||
return ErrInvalidType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package rabbitmq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var _ messaging.Publisher = (*publisher)(nil)
|
||||
|
||||
type publisher struct {
|
||||
conn *amqp.Connection
|
||||
channel *amqp.Channel
|
||||
prefix string
|
||||
exchange string
|
||||
}
|
||||
|
||||
// NewPublisher returns RabbitMQ message Publisher.
|
||||
func NewPublisher(url string, opts ...messaging.Option) (messaging.Publisher, error) {
|
||||
conn, err := amqp.Dial(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch, err := conn.Channel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := &publisher{
|
||||
conn: conn,
|
||||
channel: ch,
|
||||
prefix: chansPrefix,
|
||||
exchange: exchangeName,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if err := opt(ret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (pub *publisher) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
|
||||
if topic == "" {
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
data, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("%s.%s", pub.prefix, topic)
|
||||
if msg.GetSubtopic() != "" {
|
||||
subject = fmt.Sprintf("%s.%s", subject, msg.GetSubtopic())
|
||||
}
|
||||
subject = formatTopic(subject)
|
||||
|
||||
err = pub.channel.PublishWithContext(
|
||||
ctx,
|
||||
pub.exchange,
|
||||
subject,
|
||||
false,
|
||||
false,
|
||||
amqp.Publishing{
|
||||
Headers: amqp.Table{},
|
||||
ContentType: "application/octet-stream",
|
||||
AppId: "magistrala-publisher",
|
||||
Body: data,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pub *publisher) Close() error {
|
||||
return pub.conn.Close()
|
||||
}
|
||||
|
||||
func formatTopic(topic string) string {
|
||||
return strings.ReplaceAll(topic, ">", "#")
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package rabbitmq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
// SubjectAllChannels represents subject to subscribe for all the channels.
|
||||
SubjectAllChannels = "channels.#"
|
||||
|
||||
exchangeName = "messages"
|
||||
chansPrefix = "channels"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotSubscribed indicates that the topic is not subscribed to.
|
||||
ErrNotSubscribed = errors.New("not subscribed")
|
||||
|
||||
// ErrEmptyTopic indicates the absence of topic.
|
||||
ErrEmptyTopic = errors.New("empty topic")
|
||||
|
||||
// ErrEmptyID indicates the absence of ID.
|
||||
ErrEmptyID = errors.New("empty ID")
|
||||
)
|
||||
var _ messaging.PubSub = (*pubsub)(nil)
|
||||
|
||||
type subscription struct {
|
||||
cancel func() error
|
||||
}
|
||||
type pubsub struct {
|
||||
publisher
|
||||
logger *slog.Logger
|
||||
subscriptions map[string]map[string]subscription
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewPubSub returns RabbitMQ message publisher/subscriber.
|
||||
func NewPubSub(url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) {
|
||||
conn, err := amqp.Dial(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch, err := conn.Channel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := &pubsub{
|
||||
publisher: publisher{
|
||||
conn: conn,
|
||||
channel: ch,
|
||||
exchange: exchangeName,
|
||||
prefix: chansPrefix,
|
||||
},
|
||||
logger: logger,
|
||||
subscriptions: make(map[string]map[string]subscription),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if err := opt(ret); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
|
||||
if cfg.ID == "" {
|
||||
return ErrEmptyID
|
||||
}
|
||||
if cfg.Topic == "" {
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
ps.mu.Lock()
|
||||
|
||||
cfg.Topic = formatTopic(cfg.Topic)
|
||||
// Check topic
|
||||
s, ok := ps.subscriptions[cfg.Topic]
|
||||
if ok {
|
||||
// Check client ID
|
||||
if _, ok := s[cfg.ID]; ok {
|
||||
// Unlocking, so that Unsubscribe() can access ps.subscriptions
|
||||
ps.mu.Unlock()
|
||||
if err := ps.Unsubscribe(ctx, cfg.ID, cfg.Topic); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ps.mu.Lock()
|
||||
// value of s can be changed while ps.mu is unlocked
|
||||
s = ps.subscriptions[cfg.Topic]
|
||||
}
|
||||
}
|
||||
defer ps.mu.Unlock()
|
||||
if s == nil {
|
||||
s = make(map[string]subscription)
|
||||
ps.subscriptions[cfg.Topic] = s
|
||||
}
|
||||
|
||||
clientID := fmt.Sprintf("%s-%s", cfg.Topic, cfg.ID)
|
||||
|
||||
queue, err := ps.channel.QueueDeclare(clientID, true, false, false, false, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ps.channel.QueueBind(queue.Name, cfg.Topic, ps.exchange, false, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msgs, err := ps.channel.Consume(queue.Name, clientID, true, false, false, false, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go ps.handle(msgs, cfg.Handler)
|
||||
s[cfg.ID] = subscription{
|
||||
cancel: func() error {
|
||||
if err := ps.channel.Cancel(clientID, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return cfg.Handler.Cancel()
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error {
|
||||
if id == "" {
|
||||
return ErrEmptyID
|
||||
}
|
||||
if topic == "" {
|
||||
return ErrEmptyTopic
|
||||
}
|
||||
ps.mu.Lock()
|
||||
defer ps.mu.Unlock()
|
||||
|
||||
topic = formatTopic(topic)
|
||||
// Check topic
|
||||
s, ok := ps.subscriptions[topic]
|
||||
if !ok {
|
||||
return ErrNotSubscribed
|
||||
}
|
||||
// Check topic ID
|
||||
current, ok := s[id]
|
||||
if !ok {
|
||||
return ErrNotSubscribed
|
||||
}
|
||||
if current.cancel != nil {
|
||||
if err := current.cancel(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := ps.channel.QueueUnbind(topic, topic, exchangeName, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(s, id)
|
||||
if len(s) == 0 {
|
||||
delete(ps.subscriptions, topic)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *pubsub) handle(deliveries <-chan amqp.Delivery, h messaging.MessageHandler) {
|
||||
for d := range deliveries {
|
||||
var msg messaging.Message
|
||||
if err := proto.Unmarshal(d.Body, &msg); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
|
||||
return
|
||||
}
|
||||
if err := h.Handle(&msg); err != nil {
|
||||
ps.logger.Warn(fmt.Sprintf("Failed to handle Magistrala message: %s", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,460 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package rabbitmq_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/rabbitmq"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
topic = "topic"
|
||||
chansPrefix = "channels"
|
||||
channel = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
|
||||
subtopic = "engine"
|
||||
clientID = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b"
|
||||
exchangeName = "messages"
|
||||
)
|
||||
|
||||
var (
|
||||
msgChan = make(chan *messaging.Message)
|
||||
data = []byte("payload")
|
||||
)
|
||||
|
||||
var errFailedHandleMessage = errors.New("failed to handle magistrala message")
|
||||
|
||||
func TestPublisher(t *testing.T) {
|
||||
// Subscribing with topic, and with subtopic, so that we can publish messages.
|
||||
conn, ch, err := newConn()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
topicChan := subscribe(t, ch, fmt.Sprintf("%s.%s", chansPrefix, topic))
|
||||
subtopicChan := subscribe(t, ch, fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic))
|
||||
|
||||
go rabbitHandler(topicChan, handler{})
|
||||
go rabbitHandler(subtopicChan, handler{})
|
||||
|
||||
t.Cleanup(func() {
|
||||
conn.Close()
|
||||
ch.Close()
|
||||
})
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
channel string
|
||||
subtopic string
|
||||
payload []byte
|
||||
}{
|
||||
{
|
||||
desc: "publish message with nil payload",
|
||||
payload: nil,
|
||||
},
|
||||
{
|
||||
desc: "publish message with string payload",
|
||||
payload: data,
|
||||
},
|
||||
{
|
||||
desc: "publish message with channel",
|
||||
payload: data,
|
||||
channel: channel,
|
||||
},
|
||||
{
|
||||
desc: "publish message with subtopic",
|
||||
payload: data,
|
||||
subtopic: subtopic,
|
||||
},
|
||||
{
|
||||
desc: "publish message with channel and subtopic",
|
||||
payload: data,
|
||||
channel: channel,
|
||||
subtopic: subtopic,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
expectedMsg := messaging.Message{
|
||||
Publisher: clientID,
|
||||
Channel: tc.channel,
|
||||
Subtopic: tc.subtopic,
|
||||
Payload: tc.payload,
|
||||
}
|
||||
err = pubsub.Publish(context.TODO(), topic, &expectedMsg)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error: %s", tc.desc, err))
|
||||
|
||||
receivedMsg := <-msgChan
|
||||
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
// Creating rabbitmq connection and channel, so that we can publish messages.
|
||||
conn, ch, err := newConn()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
t.Cleanup(func() {
|
||||
conn.Close()
|
||||
ch.Close()
|
||||
})
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic string
|
||||
clientID string
|
||||
err error
|
||||
handler messaging.MessageHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a topic with an ID",
|
||||
topic: topic,
|
||||
clientID: "clientid1",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid1"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same topic with a different ID",
|
||||
topic: topic,
|
||||
clientID: "clientid2",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid2"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an already subscribed topic with an ID",
|
||||
topic: topic,
|
||||
clientID: "clientid1",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid1"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", topic, subtopic),
|
||||
clientID: "clientid1",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid1"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an already subscribed topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", topic, subtopic),
|
||||
clientID: "clientid1",
|
||||
err: nil,
|
||||
handler: handler{false, "clientid1"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty topic with an ID",
|
||||
topic: "",
|
||||
clientID: "clientid1",
|
||||
err: rabbitmq.ErrEmptyTopic,
|
||||
handler: handler{false, "clientid1"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with empty id",
|
||||
topic: topic,
|
||||
clientID: "",
|
||||
err: rabbitmq.ErrEmptyID,
|
||||
handler: handler{false, ""},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: tc.clientID,
|
||||
Topic: tc.topic,
|
||||
Handler: tc.handler,
|
||||
}
|
||||
err := pubsub.Subscribe(context.TODO(), subCfg)
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
|
||||
|
||||
if tc.err == nil {
|
||||
expectedMsg := messaging.Message{
|
||||
Publisher: "CLIENTID",
|
||||
Channel: channel,
|
||||
Subtopic: subtopic,
|
||||
Payload: data,
|
||||
}
|
||||
|
||||
data, err := proto.Marshal(&expectedMsg)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
err = ch.PublishWithContext(
|
||||
context.Background(),
|
||||
exchangeName,
|
||||
tc.topic,
|
||||
false,
|
||||
false,
|
||||
amqp.Publishing{
|
||||
Headers: amqp.Table{},
|
||||
ContentType: "application/octet-stream",
|
||||
AppId: "magistrala-publisher",
|
||||
Body: data,
|
||||
})
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
receivedMsg := <-msgChan
|
||||
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsubscribe(t *testing.T) {
|
||||
// Test Subscribe and Unsubscribe
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic string
|
||||
clientID string
|
||||
err error
|
||||
subscribe bool // True for subscribe and false for unsubscribe.
|
||||
handler messaging.MessageHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid4",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{false, "clientid4"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same topic with a different ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid9",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{false, "clientid9"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid4",
|
||||
err: nil,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from same topic with different ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid9",
|
||||
err: nil,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid9"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a non-existent topic with an ID",
|
||||
topic: "h",
|
||||
clientID: "clientid4",
|
||||
err: rabbitmq.ErrNotSubscribed,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from an already unsubscribed topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "clientid4",
|
||||
err: rabbitmq.ErrNotSubscribed,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientidd4",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{false, "clientidd4"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientidd4",
|
||||
err: nil,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientidd4"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic, subtopic),
|
||||
clientID: "clientid4",
|
||||
err: rabbitmq.ErrNotSubscribed,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from an empty topic with an ID",
|
||||
topic: "",
|
||||
clientID: "clientid4",
|
||||
err: rabbitmq.ErrEmptyTopic,
|
||||
subscribe: false,
|
||||
handler: handler{false, "clientid4"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with empty ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic),
|
||||
clientID: "",
|
||||
err: rabbitmq.ErrEmptyID,
|
||||
subscribe: false,
|
||||
handler: handler{false, ""},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a new topic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"),
|
||||
clientID: "clientid55",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{true, "clientid5"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with an ID with failing handler",
|
||||
topic: fmt.Sprintf("%s.%s", chansPrefix, topic+"2"),
|
||||
clientID: "clientid55",
|
||||
err: errFailedHandleMessage,
|
||||
subscribe: false,
|
||||
handler: handler{true, "clientid5"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a new topic with subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic),
|
||||
clientID: "clientid55",
|
||||
err: nil,
|
||||
subscribe: true,
|
||||
handler: handler{true, "clientid5"},
|
||||
},
|
||||
{
|
||||
desc: "Unsubscribe from a topic with subtopic with an ID with failing handler",
|
||||
topic: fmt.Sprintf("%s.%s.%s", chansPrefix, topic+"2", subtopic),
|
||||
clientID: "clientid55",
|
||||
err: errFailedHandleMessage,
|
||||
subscribe: false,
|
||||
handler: handler{true, "clientid5"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: tc.clientID,
|
||||
Topic: tc.topic,
|
||||
Handler: tc.handler,
|
||||
}
|
||||
switch tc.subscribe {
|
||||
case true:
|
||||
err := pubsub.Subscribe(context.TODO(), subCfg)
|
||||
assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
|
||||
default:
|
||||
err := pubsub.Unsubscribe(context.TODO(), tc.clientID, tc.topic)
|
||||
assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubSub(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
topic string
|
||||
clientID string
|
||||
err error
|
||||
handler messaging.MessageHandler
|
||||
}{
|
||||
{
|
||||
desc: "Subscribe to a topic with an ID",
|
||||
topic: topic,
|
||||
clientID: clientID,
|
||||
err: nil,
|
||||
handler: handler{false, clientID},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to the same topic with a different ID",
|
||||
topic: topic,
|
||||
clientID: clientID + "1",
|
||||
err: nil,
|
||||
handler: handler{false, clientID + "1"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with a subtopic with an ID",
|
||||
topic: fmt.Sprintf("%s.%s", topic, subtopic),
|
||||
clientID: clientID + "2",
|
||||
err: nil,
|
||||
handler: handler{false, clientID + "2"},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to an empty topic with an ID",
|
||||
topic: "",
|
||||
clientID: clientID,
|
||||
err: rabbitmq.ErrEmptyTopic,
|
||||
handler: handler{false, clientID},
|
||||
},
|
||||
{
|
||||
desc: "Subscribe to a topic with empty id",
|
||||
topic: topic,
|
||||
clientID: "",
|
||||
err: rabbitmq.ErrEmptyID,
|
||||
handler: handler{false, ""},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
subject := ""
|
||||
if tc.topic != "" {
|
||||
subject = fmt.Sprintf("%s.%s", chansPrefix, tc.topic)
|
||||
}
|
||||
subCfg := messaging.SubscriberConfig{
|
||||
ID: tc.clientID,
|
||||
Topic: subject,
|
||||
Handler: tc.handler,
|
||||
}
|
||||
err := pubsub.Subscribe(context.TODO(), subCfg)
|
||||
|
||||
switch tc.err {
|
||||
case nil:
|
||||
// If no error, publish message, and receive after subscribing.
|
||||
expectedMsg := messaging.Message{
|
||||
Channel: channel,
|
||||
Payload: data,
|
||||
}
|
||||
|
||||
err = pubsub.Publish(context.TODO(), tc.topic, &expectedMsg)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s got unexpected error: %s", tc.desc, err))
|
||||
|
||||
receivedMsg := <-msgChan
|
||||
assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg))
|
||||
|
||||
err = pubsub.Unsubscribe(context.TODO(), tc.clientID, fmt.Sprintf("%s.%s", chansPrefix, tc.topic))
|
||||
assert.Nil(t, err, fmt.Sprintf("%s got unexpected error: %s", tc.desc, err))
|
||||
default:
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, err, tc.err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
fail bool
|
||||
publisher string
|
||||
}
|
||||
|
||||
func (h handler) Handle(msg *messaging.Message) error {
|
||||
if msg.GetPublisher() != h.publisher {
|
||||
msgChan <- msg
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h handler) Cancel() error {
|
||||
if h.fail {
|
||||
return errFailedHandleMessage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package rabbitmq_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/rabbitmq"
|
||||
"github.com/ory/dockertest/v3"
|
||||
amqp "github.com/rabbitmq/amqp091-go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
port = "5672/tcp"
|
||||
brokerName = "rabbitmq"
|
||||
brokerVersion = "3.12.12-alpine"
|
||||
)
|
||||
|
||||
var (
|
||||
publisher messaging.Publisher
|
||||
pubsub messaging.PubSub
|
||||
logger *slog.Logger
|
||||
address string
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
container, err := pool.Run(brokerName, brokerVersion, []string{})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start container: %s", err)
|
||||
}
|
||||
handleInterrupt(pool, container)
|
||||
|
||||
address = fmt.Sprintf("amqp://%s:%s", "localhost", container.GetPort(port))
|
||||
if err := pool.Retry(func() error {
|
||||
publisher, err = rabbitmq.NewPublisher(address)
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
logger, err = mglog.New(os.Stdout, "debug")
|
||||
if err != nil {
|
||||
log.Fatal(err.Error())
|
||||
}
|
||||
if err := pool.Retry(func() error {
|
||||
pubsub, err = rabbitmq.NewPubSub(address, logger)
|
||||
return err
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func newConn() (*amqp.Connection, *amqp.Channel, error) {
|
||||
conn, err := amqp.Dial(address)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
ch, err := conn.Channel()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return conn, ch, nil
|
||||
}
|
||||
|
||||
func rabbitHandler(deliveries <-chan amqp.Delivery, h messaging.MessageHandler) {
|
||||
for d := range deliveries {
|
||||
var msg messaging.Message
|
||||
if err := proto.Unmarshal(d.Body, &msg); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err))
|
||||
return
|
||||
}
|
||||
if err := h.Handle(&msg); err != nil {
|
||||
logger.Warn(fmt.Sprintf("Failed to handle Magistrala message: %s", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func subscribe(t *testing.T, ch *amqp.Channel, topic string) <-chan amqp.Delivery {
|
||||
_, err := ch.QueueDeclare(topic, true, true, true, false, nil)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
err = ch.QueueBind(topic, topic, exchangeName, false, nil)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
clientID := fmt.Sprintf("%s-%s", topic, clientID)
|
||||
msgs, err := ch.Consume(topic, clientID, true, false, false, false, nil)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
return msgs
|
||||
}
|
||||
|
||||
func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) {
|
||||
c := make(chan os.Signal, 2)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-c
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package tracing provides tracing instrumentation for Magistrala things policies service.
|
||||
//
|
||||
// This package provides tracing middleware for Magistrala things policies service.
|
||||
// It can be used to trace incoming requests and add tracing capabilities to
|
||||
// Magistrala things policies service.
|
||||
//
|
||||
// For more details about tracing instrumentation for Magistrala messaging refer
|
||||
// to the documentation at https://docs.magistrala.abstractmachines.fr/tracing/.
|
||||
package tracing
|
||||
@@ -1,54 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/tracing"
|
||||
"github.com/absmach/magistrala/pkg/server"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// Traced operations.
|
||||
const publishOP = "publish"
|
||||
|
||||
var defaultAttributes = []attribute.KeyValue{
|
||||
attribute.String("messaging.system", "rabbitmq"),
|
||||
attribute.String("network.protocol.name", "amqp"),
|
||||
attribute.String("network.protocol.version", "3.9.20"),
|
||||
attribute.String("messaging.rabbitmq.destination.routing_key", "magistrala"),
|
||||
}
|
||||
|
||||
var _ messaging.Publisher = (*publisherMiddleware)(nil)
|
||||
|
||||
type publisherMiddleware struct {
|
||||
publisher messaging.Publisher
|
||||
tracer trace.Tracer
|
||||
host server.Config
|
||||
}
|
||||
|
||||
func NewPublisher(config server.Config, tracer trace.Tracer, publisher messaging.Publisher) messaging.Publisher {
|
||||
pub := &publisherMiddleware{
|
||||
publisher: publisher,
|
||||
tracer: tracer,
|
||||
host: config,
|
||||
}
|
||||
|
||||
return pub
|
||||
}
|
||||
|
||||
func (pm *publisherMiddleware) Publish(ctx context.Context, topic string, msg *messaging.Message) error {
|
||||
ctx, span := tracing.CreateSpan(ctx, publishOP, msg.GetPublisher(), topic, msg.GetSubtopic(), len(msg.GetPayload()), pm.host, trace.SpanKindClient, pm.tracer)
|
||||
defer span.End()
|
||||
|
||||
span.SetAttributes(defaultAttributes...)
|
||||
|
||||
return pm.publisher.Publish(ctx, topic, msg)
|
||||
}
|
||||
|
||||
func (pm *publisherMiddleware) Close() error {
|
||||
return pm.publisher.Close()
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/messaging"
|
||||
"github.com/absmach/magistrala/pkg/messaging/tracing"
|
||||
"github.com/absmach/magistrala/pkg/server"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// Constants to define different operations to be traced.
|
||||
const (
|
||||
subscribeOP = "receive"
|
||||
unsubscribeOp = "unsubscribe" // This is not specified in the open telemetry spec.
|
||||
processOp = "process"
|
||||
)
|
||||
|
||||
var _ messaging.PubSub = (*pubsubMiddleware)(nil)
|
||||
|
||||
type pubsubMiddleware struct {
|
||||
publisherMiddleware
|
||||
pubsub messaging.PubSub
|
||||
host server.Config
|
||||
}
|
||||
|
||||
// NewPubSub creates a new pubsub middleware that traces pubsub operations.
|
||||
func NewPubSub(config server.Config, tracer trace.Tracer, pubsub messaging.PubSub) messaging.PubSub {
|
||||
pb := &pubsubMiddleware{
|
||||
publisherMiddleware: publisherMiddleware{
|
||||
publisher: pubsub,
|
||||
tracer: tracer,
|
||||
host: config,
|
||||
},
|
||||
pubsub: pubsub,
|
||||
host: config,
|
||||
}
|
||||
|
||||
return pb
|
||||
}
|
||||
|
||||
// Subscribe creates a new subscription and traces the operation.
|
||||
func (pm *pubsubMiddleware) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error {
|
||||
ctx, span := tracing.CreateSpan(ctx, subscribeOP, cfg.ID, cfg.Topic, "", 0, pm.host, trace.SpanKindClient, pm.tracer)
|
||||
defer span.End()
|
||||
|
||||
span.SetAttributes(defaultAttributes...)
|
||||
|
||||
cfg.Handler = &traceHandler{
|
||||
ctx: ctx,
|
||||
handler: cfg.Handler,
|
||||
tracer: pm.tracer,
|
||||
host: pm.host,
|
||||
topic: cfg.Topic,
|
||||
clientID: cfg.ID,
|
||||
}
|
||||
|
||||
return pm.pubsub.Subscribe(ctx, cfg)
|
||||
}
|
||||
|
||||
// Unsubscribe removes an existing subscription and traces the operation.
|
||||
func (pm *pubsubMiddleware) Unsubscribe(ctx context.Context, id, topic string) error {
|
||||
ctx, span := tracing.CreateSpan(ctx, unsubscribeOp, id, topic, "", 0, pm.host, trace.SpanKindInternal, pm.tracer)
|
||||
defer span.End()
|
||||
|
||||
span.SetAttributes(defaultAttributes...)
|
||||
|
||||
return pm.pubsub.Unsubscribe(ctx, id, topic)
|
||||
}
|
||||
|
||||
// TraceHandler is used to trace the message handling operation.
|
||||
type traceHandler struct {
|
||||
ctx context.Context
|
||||
handler messaging.MessageHandler
|
||||
tracer trace.Tracer
|
||||
host server.Config
|
||||
topic string
|
||||
clientID string
|
||||
}
|
||||
|
||||
// Handle instruments the message handling operation.
|
||||
func (h *traceHandler) Handle(msg *messaging.Message) error {
|
||||
_, span := tracing.CreateSpan(h.ctx, processOp, h.clientID, h.topic, msg.GetSubtopic(), len(msg.GetPayload()), h.host, trace.SpanKindConsumer, h.tracer)
|
||||
defer span.End()
|
||||
|
||||
span.SetAttributes(defaultAttributes...)
|
||||
|
||||
return h.handler.Handle(msg)
|
||||
}
|
||||
|
||||
// Cancel cancels the message handling operation.
|
||||
func (h *traceHandler) Cancel() error {
|
||||
return h.handler.Cancel()
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package tracing provides tracing instrumentation for Magistrala things policies service.
|
||||
//
|
||||
// This package provides tracing middleware for Magistrala things policies service.
|
||||
// It can be used to trace incoming requests and add tracing capabilities to
|
||||
// Magistrala things policies service.
|
||||
//
|
||||
// For more details about tracing instrumentation for Magistrala messaging refer
|
||||
// to the documentation at https://docs.magistrala.abstractmachines.fr/tracing/.
|
||||
package tracing
|
||||
@@ -1,44 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/server"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
var defaultAttributes = []attribute.KeyValue{
|
||||
attribute.Bool("messaging.destination.anonymous", false),
|
||||
attribute.String("messaging.destination.template", "channels/{channelID}/messages/*"),
|
||||
attribute.Bool("messaging.destination.temporary", true),
|
||||
attribute.String("network.transport", "tcp"),
|
||||
attribute.String("network.type", "ipv4"),
|
||||
}
|
||||
|
||||
func CreateSpan(ctx context.Context, operation, clientID, topic, subTopic string, msgSize int, cfg server.Config, spanKind trace.SpanKind, tracer trace.Tracer) (context.Context, trace.Span) {
|
||||
subject := fmt.Sprintf("channels.%s.messages", topic)
|
||||
if subTopic != "" {
|
||||
subject = fmt.Sprintf("%s.%s", subject, subTopic)
|
||||
}
|
||||
spanName := fmt.Sprintf("%s %s", subject, operation)
|
||||
|
||||
kvOpts := []attribute.KeyValue{
|
||||
attribute.String("messaging.operation", operation),
|
||||
attribute.String("messaging.client_id", clientID),
|
||||
attribute.String("messaging.destination.name", subject),
|
||||
attribute.String("server.address", cfg.Host),
|
||||
attribute.String("server.socket.port", cfg.Port),
|
||||
}
|
||||
|
||||
if msgSize > 0 {
|
||||
kvOpts = append(kvOpts, attribute.Int("messaging.message.payload_size_bytes", msgSize))
|
||||
}
|
||||
|
||||
kvOpts = append(kvOpts, defaultAttributes...)
|
||||
|
||||
return tracer.Start(ctx, spanName, trace.WithAttributes(kvOpts...), trace.WithSpanKind(spanKind))
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package oauth2 contains the domain concept definitions needed to support
|
||||
// Magistrala ui service OAuth2 functionality.
|
||||
package oauth2
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package google contains the domain concept definitions needed to support
|
||||
// Magistrala services for Google OAuth2 functionality.
|
||||
package google
|
||||
@@ -1,132 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
mgoauth2 "github.com/absmach/magistrala/pkg/oauth2"
|
||||
uclient "github.com/absmach/magistrala/users"
|
||||
"golang.org/x/oauth2"
|
||||
googleoauth2 "golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const (
|
||||
providerName = "google"
|
||||
defTimeout = 1 * time.Minute
|
||||
userInfoURL = "https://www.googleapis.com/oauth2/v2/userinfo?access_token="
|
||||
tokenInfoURL = "https://oauth2.googleapis.com/tokeninfo?access_token="
|
||||
)
|
||||
|
||||
var scopes = []string{
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
|
||||
var _ mgoauth2.Provider = (*config)(nil)
|
||||
|
||||
type config struct {
|
||||
config *oauth2.Config
|
||||
state string
|
||||
uiRedirectURL string
|
||||
errorURL string
|
||||
}
|
||||
|
||||
// NewProvider returns a new Google OAuth provider.
|
||||
func NewProvider(cfg mgoauth2.Config, uiRedirectURL, errorURL string) mgoauth2.Provider {
|
||||
return &config{
|
||||
config: &oauth2.Config{
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
Endpoint: googleoauth2.Endpoint,
|
||||
RedirectURL: cfg.RedirectURL,
|
||||
Scopes: scopes,
|
||||
},
|
||||
state: cfg.State,
|
||||
uiRedirectURL: uiRedirectURL,
|
||||
errorURL: errorURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *config) Name() string {
|
||||
return providerName
|
||||
}
|
||||
|
||||
func (cfg *config) State() string {
|
||||
return cfg.state
|
||||
}
|
||||
|
||||
func (cfg *config) RedirectURL() string {
|
||||
return cfg.uiRedirectURL
|
||||
}
|
||||
|
||||
func (cfg *config) ErrorURL() string {
|
||||
return cfg.errorURL
|
||||
}
|
||||
|
||||
func (cfg *config) IsEnabled() bool {
|
||||
return cfg.config.ClientID != "" && cfg.config.ClientSecret != ""
|
||||
}
|
||||
|
||||
func (cfg *config) Exchange(ctx context.Context, code string) (oauth2.Token, error) {
|
||||
token, err := cfg.config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return oauth2.Token{}, err
|
||||
}
|
||||
|
||||
return *token, nil
|
||||
}
|
||||
|
||||
func (cfg *config) UserInfo(accessToken string) (uclient.User, error) {
|
||||
resp, err := http.Get(userInfoURL + url.QueryEscape(accessToken))
|
||||
if err != nil {
|
||||
return uclient.User{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return uclient.User{}, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return uclient.User{}, err
|
||||
}
|
||||
|
||||
var user struct {
|
||||
ID string `json:"id"`
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &user); err != nil {
|
||||
return uclient.User{}, err
|
||||
}
|
||||
|
||||
if user.ID == "" || user.FirstName == "" || user.LastName == "" || user.Email == "" {
|
||||
return uclient.User{}, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
client := uclient.User{
|
||||
ID: user.ID,
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
Email: user.Email,
|
||||
Metadata: map[string]interface{}{
|
||||
"oauth_provider": providerName,
|
||||
"profile_picture": user.Picture,
|
||||
},
|
||||
Status: uclient.EnabledStatus,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
@@ -1,180 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
users "github.com/absmach/magistrala/users"
|
||||
|
||||
xoauth2 "golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Provider is an autogenerated mock type for the Provider type
|
||||
type Provider struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// ErrorURL provides a mock function with given fields:
|
||||
func (_m *Provider) ErrorURL() string {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ErrorURL")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func() string); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Exchange provides a mock function with given fields: ctx, code
|
||||
func (_m *Provider) Exchange(ctx context.Context, code string) (xoauth2.Token, error) {
|
||||
ret := _m.Called(ctx, code)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Exchange")
|
||||
}
|
||||
|
||||
var r0 xoauth2.Token
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (xoauth2.Token, error)); ok {
|
||||
return rf(ctx, code)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) xoauth2.Token); ok {
|
||||
r0 = rf(ctx, code)
|
||||
} else {
|
||||
r0 = ret.Get(0).(xoauth2.Token)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, code)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// IsEnabled provides a mock function with given fields:
|
||||
func (_m *Provider) IsEnabled() bool {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for IsEnabled")
|
||||
}
|
||||
|
||||
var r0 bool
|
||||
if rf, ok := ret.Get(0).(func() bool); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Name provides a mock function with given fields:
|
||||
func (_m *Provider) Name() string {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Name")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func() string); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// RedirectURL provides a mock function with given fields:
|
||||
func (_m *Provider) RedirectURL() string {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RedirectURL")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func() string); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// State provides a mock function with given fields:
|
||||
func (_m *Provider) State() string {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for State")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
if rf, ok := ret.Get(0).(func() string); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// UserInfo provides a mock function with given fields: accessToken
|
||||
func (_m *Provider) UserInfo(accessToken string) (users.User, error) {
|
||||
ret := _m.Called(accessToken)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UserInfo")
|
||||
}
|
||||
|
||||
var r0 users.User
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(string) (users.User, error)); ok {
|
||||
return rf(accessToken)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(string) users.User); ok {
|
||||
r0 = rf(accessToken)
|
||||
} else {
|
||||
r0 = ret.Get(0).(users.User)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = rf(accessToken)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewProvider creates a new instance of Provider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewProvider(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Provider {
|
||||
mock := &Provider{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/users"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Config is the configuration for the OAuth2 provider.
|
||||
type Config struct {
|
||||
ClientID string `env:"CLIENT_ID" envDefault:""`
|
||||
ClientSecret string `env:"CLIENT_SECRET" envDefault:""`
|
||||
State string `env:"STATE" envDefault:""`
|
||||
RedirectURL string `env:"REDIRECT_URL" envDefault:""`
|
||||
}
|
||||
|
||||
// Provider is an interface that provides the OAuth2 flow for a specific provider
|
||||
// (e.g. Google, GitHub, etc.)
|
||||
//
|
||||
//go:generate mockery --name Provider --output=./mocks --filename provider.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Provider interface {
|
||||
// Name returns the name of the OAuth2 provider.
|
||||
Name() string
|
||||
|
||||
// State returns the current state for the OAuth2 flow.
|
||||
State() string
|
||||
|
||||
// RedirectURL returns the URL to redirect the user to after completing the OAuth2 flow.
|
||||
RedirectURL() string
|
||||
|
||||
// ErrorURL returns the URL to redirect the user to in case of an error during the OAuth2 flow.
|
||||
ErrorURL() string
|
||||
|
||||
// IsEnabled checks if the OAuth2 provider is enabled.
|
||||
IsEnabled() bool
|
||||
|
||||
// Exchange converts an authorization code into a token.
|
||||
Exchange(ctx context.Context, code string) (oauth2.Token, error)
|
||||
|
||||
// UserInfo retrieves the user's information using the access token.
|
||||
UserInfo(accessToken string) (users.User, error)
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package policies contains Magistrala policy definitions.
|
||||
package policies
|
||||
@@ -1,64 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package policies
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenKind = "token"
|
||||
GroupsKind = "groups"
|
||||
NewGroupKind = "new_group"
|
||||
ChannelsKind = "channels"
|
||||
NewChannelKind = "new_channel"
|
||||
ThingsKind = "things"
|
||||
NewThingKind = "new_thing"
|
||||
UsersKind = "users"
|
||||
DomainsKind = "domains"
|
||||
PlatformKind = "platform"
|
||||
)
|
||||
|
||||
const (
|
||||
GroupType = "group"
|
||||
ThingType = "thing"
|
||||
UserType = "user"
|
||||
DomainType = "domain"
|
||||
PlatformType = "platform"
|
||||
)
|
||||
|
||||
const (
|
||||
AdministratorRelation = "administrator"
|
||||
EditorRelation = "editor"
|
||||
ContributorRelation = "contributor"
|
||||
MemberRelation = "member"
|
||||
DomainRelation = "domain"
|
||||
ParentGroupRelation = "parent_group"
|
||||
RoleGroupRelation = "role_group"
|
||||
GroupRelation = "group"
|
||||
PlatformRelation = "platform"
|
||||
GuestRelation = "guest"
|
||||
)
|
||||
|
||||
const (
|
||||
AdminPermission = "admin"
|
||||
DeletePermission = "delete"
|
||||
EditPermission = "edit"
|
||||
ViewPermission = "view"
|
||||
MembershipPermission = "membership"
|
||||
SharePermission = "share"
|
||||
PublishPermission = "publish"
|
||||
SubscribePermission = "subscribe"
|
||||
CreatePermission = "create"
|
||||
)
|
||||
|
||||
const MagistralaObject = "magistrala"
|
||||
|
||||
//go:generate mockery --name Evaluator --output=./mocks --filename evaluator.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Evaluator interface {
|
||||
// CheckPolicy checks if the subject has a relation on the object.
|
||||
// It returns a non-nil error if the subject has no relation on
|
||||
// the object (which simply means the operation is denied).
|
||||
CheckPolicy(ctx context.Context, pr Policy) error
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
policies "github.com/absmach/magistrala/pkg/policies"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Evaluator is an autogenerated mock type for the Evaluator type
|
||||
type Evaluator struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// CheckPolicy provides a mock function with given fields: ctx, pr
|
||||
func (_m *Evaluator) CheckPolicy(ctx context.Context, pr policies.Policy) error {
|
||||
ret := _m.Called(ctx, pr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CheckPolicy")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) error); ok {
|
||||
r0 = rf(ctx, pr)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// NewEvaluator creates a new instance of Evaluator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewEvaluator(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Evaluator {
|
||||
mock := &Evaluator{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,301 +0,0 @@
|
||||
// Code generated by mockery v2.43.2. DO NOT EDIT.
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
policies "github.com/absmach/magistrala/pkg/policies"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// Service is an autogenerated mock type for the Service type
|
||||
type Service struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// AddPolicies provides a mock function with given fields: ctx, prs
|
||||
func (_m *Service) AddPolicies(ctx context.Context, prs []policies.Policy) error {
|
||||
ret := _m.Called(ctx, prs)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AddPolicies")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []policies.Policy) error); ok {
|
||||
r0 = rf(ctx, prs)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// AddPolicy provides a mock function with given fields: ctx, pr
|
||||
func (_m *Service) AddPolicy(ctx context.Context, pr policies.Policy) error {
|
||||
ret := _m.Called(ctx, pr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AddPolicy")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) error); ok {
|
||||
r0 = rf(ctx, pr)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// CountObjects provides a mock function with given fields: ctx, pr
|
||||
func (_m *Service) CountObjects(ctx context.Context, pr policies.Policy) (uint64, error) {
|
||||
ret := _m.Called(ctx, pr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CountObjects")
|
||||
}
|
||||
|
||||
var r0 uint64
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) (uint64, error)); ok {
|
||||
return rf(ctx, pr)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) uint64); ok {
|
||||
r0 = rf(ctx, pr)
|
||||
} else {
|
||||
r0 = ret.Get(0).(uint64)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy) error); ok {
|
||||
r1 = rf(ctx, pr)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// CountSubjects provides a mock function with given fields: ctx, pr
|
||||
func (_m *Service) CountSubjects(ctx context.Context, pr policies.Policy) (uint64, error) {
|
||||
ret := _m.Called(ctx, pr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CountSubjects")
|
||||
}
|
||||
|
||||
var r0 uint64
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) (uint64, error)); ok {
|
||||
return rf(ctx, pr)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) uint64); ok {
|
||||
r0 = rf(ctx, pr)
|
||||
} else {
|
||||
r0 = ret.Get(0).(uint64)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy) error); ok {
|
||||
r1 = rf(ctx, pr)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// DeletePolicies provides a mock function with given fields: ctx, prs
|
||||
func (_m *Service) DeletePolicies(ctx context.Context, prs []policies.Policy) error {
|
||||
ret := _m.Called(ctx, prs)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for DeletePolicies")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []policies.Policy) error); ok {
|
||||
r0 = rf(ctx, prs)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// DeletePolicyFilter provides a mock function with given fields: ctx, pr
|
||||
func (_m *Service) DeletePolicyFilter(ctx context.Context, pr policies.Policy) error {
|
||||
ret := _m.Called(ctx, pr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for DeletePolicyFilter")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) error); ok {
|
||||
r0 = rf(ctx, pr)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// ListAllObjects provides a mock function with given fields: ctx, pr
|
||||
func (_m *Service) ListAllObjects(ctx context.Context, pr policies.Policy) (policies.PolicyPage, error) {
|
||||
ret := _m.Called(ctx, pr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListAllObjects")
|
||||
}
|
||||
|
||||
var r0 policies.PolicyPage
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) (policies.PolicyPage, error)); ok {
|
||||
return rf(ctx, pr)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) policies.PolicyPage); ok {
|
||||
r0 = rf(ctx, pr)
|
||||
} else {
|
||||
r0 = ret.Get(0).(policies.PolicyPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy) error); ok {
|
||||
r1 = rf(ctx, pr)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListAllSubjects provides a mock function with given fields: ctx, pr
|
||||
func (_m *Service) ListAllSubjects(ctx context.Context, pr policies.Policy) (policies.PolicyPage, error) {
|
||||
ret := _m.Called(ctx, pr)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListAllSubjects")
|
||||
}
|
||||
|
||||
var r0 policies.PolicyPage
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) (policies.PolicyPage, error)); ok {
|
||||
return rf(ctx, pr)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy) policies.PolicyPage); ok {
|
||||
r0 = rf(ctx, pr)
|
||||
} else {
|
||||
r0 = ret.Get(0).(policies.PolicyPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy) error); ok {
|
||||
r1 = rf(ctx, pr)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListObjects provides a mock function with given fields: ctx, pr, nextPageToken, limit
|
||||
func (_m *Service) ListObjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) (policies.PolicyPage, error) {
|
||||
ret := _m.Called(ctx, pr, nextPageToken, limit)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListObjects")
|
||||
}
|
||||
|
||||
var r0 policies.PolicyPage
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, string, uint64) (policies.PolicyPage, error)); ok {
|
||||
return rf(ctx, pr, nextPageToken, limit)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, string, uint64) policies.PolicyPage); ok {
|
||||
r0 = rf(ctx, pr, nextPageToken, limit)
|
||||
} else {
|
||||
r0 = ret.Get(0).(policies.PolicyPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy, string, uint64) error); ok {
|
||||
r1 = rf(ctx, pr, nextPageToken, limit)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListPermissions provides a mock function with given fields: ctx, pr, permissionsFilter
|
||||
func (_m *Service) ListPermissions(ctx context.Context, pr policies.Policy, permissionsFilter []string) (policies.Permissions, error) {
|
||||
ret := _m.Called(ctx, pr, permissionsFilter)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListPermissions")
|
||||
}
|
||||
|
||||
var r0 policies.Permissions
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, []string) (policies.Permissions, error)); ok {
|
||||
return rf(ctx, pr, permissionsFilter)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, []string) policies.Permissions); ok {
|
||||
r0 = rf(ctx, pr, permissionsFilter)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(policies.Permissions)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy, []string) error); ok {
|
||||
r1 = rf(ctx, pr, permissionsFilter)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ListSubjects provides a mock function with given fields: ctx, pr, nextPageToken, limit
|
||||
func (_m *Service) ListSubjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) (policies.PolicyPage, error) {
|
||||
ret := _m.Called(ctx, pr, nextPageToken, limit)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListSubjects")
|
||||
}
|
||||
|
||||
var r0 policies.PolicyPage
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, string, uint64) (policies.PolicyPage, error)); ok {
|
||||
return rf(ctx, pr, nextPageToken, limit)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, policies.Policy, string, uint64) policies.PolicyPage); ok {
|
||||
r0 = rf(ctx, pr, nextPageToken, limit)
|
||||
} else {
|
||||
r0 = ret.Get(0).(policies.PolicyPage)
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, policies.Policy, string, uint64) error); ok {
|
||||
r1 = rf(ctx, pr, nextPageToken, limit)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewService(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Service {
|
||||
mock := &Service{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package policies
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Policy struct {
|
||||
// Domain contains the domain ID.
|
||||
Domain string `json:"domain,omitempty"`
|
||||
|
||||
// Subject contains the subject ID or Token.
|
||||
Subject string `json:"subject"`
|
||||
|
||||
// SubjectType contains the subject type. Supported subject types are
|
||||
// platform, group, domain, thing, users.
|
||||
SubjectType string `json:"subject_type"`
|
||||
|
||||
// SubjectKind contains the subject kind. Supported subject kinds are
|
||||
// token, users, platform, things, channels, groups, domain.
|
||||
SubjectKind string `json:"subject_kind"`
|
||||
|
||||
// SubjectRelation contains subject relations.
|
||||
SubjectRelation string `json:"subject_relation,omitempty"`
|
||||
|
||||
// Object contains the object ID.
|
||||
Object string `json:"object"`
|
||||
|
||||
// ObjectKind contains the object kind. Supported object kinds are
|
||||
// users, platform, things, channels, groups, domain.
|
||||
ObjectKind string `json:"object_kind"`
|
||||
|
||||
// ObjectType contains the object type. Supported object types are
|
||||
// platform, group, domain, thing, users.
|
||||
ObjectType string `json:"object_type"`
|
||||
|
||||
// Relation contains the relation. Supported relations are administrator, editor, contributor, member, guest, parent_group,group,domain.
|
||||
Relation string `json:"relation,omitempty"`
|
||||
|
||||
// Permission contains the permission. Supported permissions are admin, delete, edit, share, view,
|
||||
// membership, create, admin_only, edit_only, view_only, membership_only, ext_admin, ext_edit, ext_view.
|
||||
Permission string `json:"permission,omitempty"`
|
||||
}
|
||||
|
||||
func (pr Policy) String() string {
|
||||
data, err := json.Marshal(pr)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
type PolicyPage struct {
|
||||
Policies []string
|
||||
NextPageToken string
|
||||
}
|
||||
|
||||
type Permissions []string
|
||||
|
||||
// PolicyService facilitates the communication to authorization
|
||||
// services and implements Authz functionalities for spicedb
|
||||
//
|
||||
//go:generate mockery --name Service --filename service.go --quiet --note "Copyright (c) Abstract Machines"
|
||||
type Service interface {
|
||||
// AddPolicy creates a policy for the given subject, so that, after
|
||||
// AddPolicy, `subject` has a `relation` on `object`. Returns a non-nil
|
||||
// error in case of failures.
|
||||
AddPolicy(ctx context.Context, pr Policy) error
|
||||
|
||||
// AddPolicies adds new policies for given subjects. This method is
|
||||
// only allowed to use as an admin.
|
||||
AddPolicies(ctx context.Context, prs []Policy) error
|
||||
|
||||
// DeletePolicyFilter removes policy for given policy filter request.
|
||||
DeletePolicyFilter(ctx context.Context, pr Policy) error
|
||||
|
||||
// DeletePolicies deletes policies for given subjects. This method is
|
||||
// only allowed to use as an admin.
|
||||
DeletePolicies(ctx context.Context, prs []Policy) error
|
||||
|
||||
// ListObjects lists policies based on the given Policy structure.
|
||||
ListObjects(ctx context.Context, pr Policy, nextPageToken string, limit uint64) (PolicyPage, error)
|
||||
|
||||
// ListAllObjects lists all policies based on the given Policy structure.
|
||||
ListAllObjects(ctx context.Context, pr Policy) (PolicyPage, error)
|
||||
|
||||
// CountObjects count policies based on the given Policy structure.
|
||||
CountObjects(ctx context.Context, pr Policy) (uint64, error)
|
||||
|
||||
// ListSubjects lists subjects based on the given Policy structure.
|
||||
ListSubjects(ctx context.Context, pr Policy, nextPageToken string, limit uint64) (PolicyPage, error)
|
||||
|
||||
// ListAllSubjects lists all subjects based on the given Policy structure.
|
||||
ListAllSubjects(ctx context.Context, pr Policy) (PolicyPage, error)
|
||||
|
||||
// CountSubjects count policies based on the given Policy structure.
|
||||
CountSubjects(ctx context.Context, pr Policy) (uint64, error)
|
||||
|
||||
// ListPermissions lists permission betweeen given subject and object .
|
||||
ListPermissions(ctx context.Context, pr Policy, permissionsFilter []string) (Permissions, error)
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package server contains the HTTP, gRPC and CoAP server implementation.
|
||||
package spicedb
|
||||
@@ -1,64 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package spicedb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/absmach/magistrala/pkg/policies"
|
||||
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
|
||||
"github.com/authzed/authzed-go/v1"
|
||||
)
|
||||
|
||||
type policyEvaluator struct {
|
||||
client *authzed.ClientWithExperimental
|
||||
permissionClient v1.PermissionsServiceClient
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewPolicyEvaluator(client *authzed.ClientWithExperimental, logger *slog.Logger) policies.Evaluator {
|
||||
return &policyEvaluator{
|
||||
client: client,
|
||||
permissionClient: client.PermissionsServiceClient,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (pe *policyEvaluator) CheckPolicy(ctx context.Context, pr policies.Policy) error {
|
||||
checkReq := v1.CheckPermissionRequest{
|
||||
// FullyConsistent means little caching will be available, which means performance will suffer.
|
||||
// Only use if a ZedToken is not available or absolutely latest information is required.
|
||||
// If we want to avoid FullyConsistent and to improve the performance of spicedb, then we need to cache the ZEDTOKEN whenever RELATIONS is created or updated.
|
||||
// Instead of using FullyConsistent we need to use Consistency_AtLeastAsFresh, code looks like below one.
|
||||
// Consistency: &v1.Consistency{
|
||||
// Requirement: &v1.Consistency_AtLeastAsFresh{
|
||||
// AtLeastAsFresh: getRelationTupleZedTokenFromCache() ,
|
||||
// }
|
||||
// },
|
||||
// Reference: https://authzed.com/docs/reference/api-consistency
|
||||
Consistency: &v1.Consistency{
|
||||
Requirement: &v1.Consistency_FullyConsistent{
|
||||
FullyConsistent: true,
|
||||
},
|
||||
},
|
||||
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
|
||||
Permission: pr.Permission,
|
||||
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
|
||||
}
|
||||
|
||||
resp, err := pe.permissionClient.CheckPermission(ctx, &checkReq)
|
||||
if err != nil {
|
||||
return handleSpicedbError(err)
|
||||
}
|
||||
if resp.Permissionship == v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION {
|
||||
return nil
|
||||
}
|
||||
if reason, ok := v1.CheckPermissionResponse_Permissionship_name[int32(resp.Permissionship)]; ok {
|
||||
return errors.Wrap(svcerr.ErrAuthorization, errors.New(reason))
|
||||
}
|
||||
return svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -1,950 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package spicedb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/absmach/magistrala/pkg/policies"
|
||||
v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
|
||||
"github.com/authzed/authzed-go/v1"
|
||||
gstatus "google.golang.org/genproto/googleapis/rpc/status"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const defRetrieveAllLimit = 1000
|
||||
|
||||
var (
|
||||
errInvalidSubject = errors.New("invalid subject kind")
|
||||
errAddPolicies = errors.New("failed to add policies")
|
||||
errRetrievePolicies = errors.New("failed to retrieve policies")
|
||||
errRemovePolicies = errors.New("failed to remove the policies")
|
||||
errNoPolicies = errors.New("no policies provided")
|
||||
errInternal = errors.New("spicedb internal error")
|
||||
errPlatform = errors.New("invalid platform id")
|
||||
)
|
||||
|
||||
var (
|
||||
defThingsFilterPermissions = []string{
|
||||
policies.AdminPermission,
|
||||
policies.DeletePermission,
|
||||
policies.EditPermission,
|
||||
policies.ViewPermission,
|
||||
policies.SharePermission,
|
||||
policies.PublishPermission,
|
||||
policies.SubscribePermission,
|
||||
}
|
||||
|
||||
defGroupsFilterPermissions = []string{
|
||||
policies.AdminPermission,
|
||||
policies.DeletePermission,
|
||||
policies.EditPermission,
|
||||
policies.ViewPermission,
|
||||
policies.MembershipPermission,
|
||||
policies.SharePermission,
|
||||
}
|
||||
|
||||
defDomainsFilterPermissions = []string{
|
||||
policies.AdminPermission,
|
||||
policies.EditPermission,
|
||||
policies.ViewPermission,
|
||||
policies.MembershipPermission,
|
||||
policies.SharePermission,
|
||||
}
|
||||
|
||||
defPlatformFilterPermissions = []string{
|
||||
policies.AdminPermission,
|
||||
policies.MembershipPermission,
|
||||
}
|
||||
)
|
||||
|
||||
type policyService struct {
|
||||
client *authzed.ClientWithExperimental
|
||||
permissionClient v1.PermissionsServiceClient
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewPolicyService(client *authzed.ClientWithExperimental, logger *slog.Logger) policies.Service {
|
||||
return &policyService{
|
||||
client: client,
|
||||
permissionClient: client.PermissionsServiceClient,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *policyService) AddPolicy(ctx context.Context, pr policies.Policy) error {
|
||||
if err := ps.policyValidation(pr); err != nil {
|
||||
return errors.Wrap(svcerr.ErrInvalidPolicy, err)
|
||||
}
|
||||
precond, err := ps.addPolicyPreCondition(ctx, pr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updates := []*v1.RelationshipUpdate{
|
||||
{
|
||||
Operation: v1.RelationshipUpdate_OPERATION_CREATE,
|
||||
Relationship: &v1.Relationship{
|
||||
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
|
||||
Relation: pr.Relation,
|
||||
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err = ps.permissionClient.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{Updates: updates, OptionalPreconditions: precond})
|
||||
if err != nil {
|
||||
return errors.Wrap(errAddPolicies, handleSpicedbError(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *policyService) AddPolicies(ctx context.Context, prs []policies.Policy) error {
|
||||
updates := []*v1.RelationshipUpdate{}
|
||||
var preconds []*v1.Precondition
|
||||
for _, pr := range prs {
|
||||
if err := ps.policyValidation(pr); err != nil {
|
||||
return errors.Wrap(svcerr.ErrInvalidPolicy, err)
|
||||
}
|
||||
precond, err := ps.addPolicyPreCondition(ctx, pr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
preconds = append(preconds, precond...)
|
||||
updates = append(updates, &v1.RelationshipUpdate{
|
||||
Operation: v1.RelationshipUpdate_OPERATION_CREATE,
|
||||
Relationship: &v1.Relationship{
|
||||
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
|
||||
Relation: pr.Relation,
|
||||
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
|
||||
},
|
||||
})
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
return errors.Wrap(errors.ErrMalformedEntity, errNoPolicies)
|
||||
}
|
||||
_, err := ps.permissionClient.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{Updates: updates, OptionalPreconditions: preconds})
|
||||
if err != nil {
|
||||
return errors.Wrap(errAddPolicies, handleSpicedbError(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *policyService) DeletePolicyFilter(ctx context.Context, pr policies.Policy) error {
|
||||
req := &v1.DeleteRelationshipsRequest{
|
||||
RelationshipFilter: &v1.RelationshipFilter{
|
||||
ResourceType: pr.ObjectType,
|
||||
OptionalResourceId: pr.Object,
|
||||
},
|
||||
}
|
||||
|
||||
if pr.Relation != "" {
|
||||
req.RelationshipFilter.OptionalRelation = pr.Relation
|
||||
}
|
||||
|
||||
if pr.SubjectType != "" {
|
||||
req.RelationshipFilter.OptionalSubjectFilter = &v1.SubjectFilter{
|
||||
SubjectType: pr.SubjectType,
|
||||
}
|
||||
if pr.Subject != "" {
|
||||
req.RelationshipFilter.OptionalSubjectFilter.OptionalSubjectId = pr.Subject
|
||||
}
|
||||
if pr.SubjectRelation != "" {
|
||||
req.RelationshipFilter.OptionalSubjectFilter.OptionalRelation = &v1.SubjectFilter_RelationFilter{
|
||||
Relation: pr.SubjectRelation,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := ps.permissionClient.DeleteRelationships(ctx, req); err != nil {
|
||||
return errors.Wrap(errRemovePolicies, handleSpicedbError(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *policyService) DeletePolicies(ctx context.Context, prs []policies.Policy) error {
|
||||
updates := []*v1.RelationshipUpdate{}
|
||||
for _, pr := range prs {
|
||||
if err := ps.policyValidation(pr); err != nil {
|
||||
return errors.Wrap(svcerr.ErrInvalidPolicy, err)
|
||||
}
|
||||
updates = append(updates, &v1.RelationshipUpdate{
|
||||
Operation: v1.RelationshipUpdate_OPERATION_DELETE,
|
||||
Relationship: &v1.Relationship{
|
||||
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
|
||||
Relation: pr.Relation,
|
||||
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
|
||||
},
|
||||
})
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
return errors.Wrap(errors.ErrMalformedEntity, errNoPolicies)
|
||||
}
|
||||
_, err := ps.permissionClient.WriteRelationships(ctx, &v1.WriteRelationshipsRequest{Updates: updates})
|
||||
if err != nil {
|
||||
return errors.Wrap(errRemovePolicies, handleSpicedbError(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *policyService) ListObjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) (policies.PolicyPage, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
res, npt, err := ps.retrieveObjects(ctx, pr, nextPageToken, limit)
|
||||
if err != nil {
|
||||
return policies.PolicyPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
var page policies.PolicyPage
|
||||
for _, tuple := range res {
|
||||
page.Policies = append(page.Policies, tuple.Object)
|
||||
}
|
||||
page.NextPageToken = npt
|
||||
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) ListAllObjects(ctx context.Context, pr policies.Policy) (policies.PolicyPage, error) {
|
||||
res, err := ps.retrieveAllObjects(ctx, pr)
|
||||
if err != nil {
|
||||
return policies.PolicyPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
var page policies.PolicyPage
|
||||
for _, tuple := range res {
|
||||
page.Policies = append(page.Policies, tuple.Object)
|
||||
}
|
||||
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) CountObjects(ctx context.Context, pr policies.Policy) (uint64, error) {
|
||||
var count uint64
|
||||
nextPageToken := ""
|
||||
for {
|
||||
relationTuples, npt, err := ps.retrieveObjects(ctx, pr, nextPageToken, defRetrieveAllLimit)
|
||||
if err != nil {
|
||||
return count, err
|
||||
}
|
||||
count = count + uint64(len(relationTuples))
|
||||
if npt == "" {
|
||||
break
|
||||
}
|
||||
nextPageToken = npt
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) ListSubjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) (policies.PolicyPage, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
res, npt, err := ps.retrieveSubjects(ctx, pr, nextPageToken, limit)
|
||||
if err != nil {
|
||||
return policies.PolicyPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
var page policies.PolicyPage
|
||||
for _, tuple := range res {
|
||||
page.Policies = append(page.Policies, tuple.Subject)
|
||||
}
|
||||
page.NextPageToken = npt
|
||||
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) ListAllSubjects(ctx context.Context, pr policies.Policy) (policies.PolicyPage, error) {
|
||||
res, err := ps.retrieveAllSubjects(ctx, pr)
|
||||
if err != nil {
|
||||
return policies.PolicyPage{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
var page policies.PolicyPage
|
||||
for _, tuple := range res {
|
||||
page.Policies = append(page.Policies, tuple.Subject)
|
||||
}
|
||||
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) CountSubjects(ctx context.Context, pr policies.Policy) (uint64, error) {
|
||||
var count uint64
|
||||
nextPageToken := ""
|
||||
for {
|
||||
relationTuples, npt, err := ps.retrieveSubjects(ctx, pr, nextPageToken, defRetrieveAllLimit)
|
||||
if err != nil {
|
||||
return count, err
|
||||
}
|
||||
count = count + uint64(len(relationTuples))
|
||||
if npt == "" {
|
||||
break
|
||||
}
|
||||
nextPageToken = npt
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) ListPermissions(ctx context.Context, pr policies.Policy, permissionsFilter []string) (policies.Permissions, error) {
|
||||
if len(permissionsFilter) == 0 {
|
||||
switch pr.ObjectType {
|
||||
case policies.ThingType:
|
||||
permissionsFilter = defThingsFilterPermissions
|
||||
case policies.GroupType:
|
||||
permissionsFilter = defGroupsFilterPermissions
|
||||
case policies.PlatformType:
|
||||
permissionsFilter = defPlatformFilterPermissions
|
||||
case policies.DomainType:
|
||||
permissionsFilter = defDomainsFilterPermissions
|
||||
default:
|
||||
return nil, svcerr.ErrMalformedEntity
|
||||
}
|
||||
}
|
||||
pers, err := ps.retrievePermissions(ctx, pr, permissionsFilter)
|
||||
if err != nil {
|
||||
return []string{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return pers, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) policyValidation(pr policies.Policy) error {
|
||||
if pr.ObjectType == policies.PlatformType && pr.Object != policies.MagistralaObject {
|
||||
return errPlatform
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *policyService) addPolicyPreCondition(ctx context.Context, pr policies.Policy) ([]*v1.Precondition, error) {
|
||||
// Checks are required for following ( -> means adding)
|
||||
// 1.) user -> group (both user groups and channels)
|
||||
// 2.) user -> thing
|
||||
// 3.) group -> group (both for adding parent_group and channels)
|
||||
// 4.) group (channel) -> thing
|
||||
// 5.) user -> domain
|
||||
|
||||
switch {
|
||||
// 1.) user -> group (both user groups and channels)
|
||||
// Checks :
|
||||
// - USER with ANY RELATION to DOMAIN
|
||||
// - GROUP with DOMAIN RELATION to DOMAIN
|
||||
case pr.SubjectType == policies.UserType && pr.ObjectType == policies.GroupType:
|
||||
return ps.userGroupPreConditions(ctx, pr)
|
||||
|
||||
// 2.) user -> thing
|
||||
// Checks :
|
||||
// - USER with ANY RELATION to DOMAIN
|
||||
// - THING with DOMAIN RELATION to DOMAIN
|
||||
case pr.SubjectType == policies.UserType && pr.ObjectType == policies.ThingType:
|
||||
return ps.userThingPreConditions(ctx, pr)
|
||||
|
||||
// 3.) group -> group (both for adding parent_group and channels)
|
||||
// Checks :
|
||||
// - CHILD_GROUP with out PARENT_GROUP RELATION with any GROUP
|
||||
case pr.SubjectType == policies.GroupType && pr.ObjectType == policies.GroupType:
|
||||
return groupPreConditions(pr)
|
||||
|
||||
// 4.) group (channel) -> thing
|
||||
// Checks :
|
||||
// - GROUP (channel) with DOMAIN RELATION to DOMAIN
|
||||
// - NO GROUP should not have PARENT_GROUP RELATION with GROUP (channel)
|
||||
// - THING with DOMAIN RELATION to DOMAIN
|
||||
case pr.SubjectType == policies.GroupType && pr.ObjectType == policies.ThingType:
|
||||
return channelThingPreCondition(pr)
|
||||
|
||||
// 5.) user -> domain
|
||||
// Checks :
|
||||
// - User doesn't have any relation with domain
|
||||
case pr.SubjectType == policies.UserType && pr.ObjectType == policies.DomainType:
|
||||
return ps.userDomainPreConditions(ctx, pr)
|
||||
|
||||
// Check thing and group not belongs to other domain before adding to domain
|
||||
case pr.SubjectType == policies.DomainType && pr.Relation == policies.DomainRelation && (pr.ObjectType == policies.ThingType || pr.ObjectType == policies.GroupType):
|
||||
preconds := []*v1.Precondition{
|
||||
{
|
||||
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: pr.ObjectType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return preconds, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) userGroupPreConditions(ctx context.Context, pr policies.Policy) ([]*v1.Precondition, error) {
|
||||
var preconds []*v1.Precondition
|
||||
|
||||
// user should not have any relation with group
|
||||
preconds = append(preconds, &v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.GroupType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.UserType,
|
||||
OptionalSubjectId: pr.Subject,
|
||||
},
|
||||
},
|
||||
})
|
||||
isSuperAdmin := false
|
||||
if err := ps.checkPolicy(ctx, policies.Policy{
|
||||
Subject: pr.Subject,
|
||||
SubjectType: pr.SubjectType,
|
||||
Permission: policies.AdminPermission,
|
||||
Object: policies.MagistralaObject,
|
||||
ObjectType: policies.PlatformType,
|
||||
}); err == nil {
|
||||
isSuperAdmin = true
|
||||
}
|
||||
|
||||
if !isSuperAdmin {
|
||||
preconds = append(preconds, &v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.DomainType,
|
||||
OptionalResourceId: pr.Domain,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.UserType,
|
||||
OptionalSubjectId: pr.Subject,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
switch {
|
||||
case pr.ObjectKind == policies.NewGroupKind || pr.ObjectKind == policies.NewChannelKind:
|
||||
preconds = append(preconds,
|
||||
&v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.GroupType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
default:
|
||||
preconds = append(preconds,
|
||||
&v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.GroupType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
OptionalSubjectId: pr.Domain,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
return preconds, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) userThingPreConditions(ctx context.Context, pr policies.Policy) ([]*v1.Precondition, error) {
|
||||
var preconds []*v1.Precondition
|
||||
|
||||
// user should not have any relation with thing
|
||||
preconds = append(preconds, &v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.ThingType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.UserType,
|
||||
OptionalSubjectId: pr.Subject,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
isSuperAdmin := false
|
||||
if err := ps.checkPolicy(ctx, policies.Policy{
|
||||
Subject: pr.Subject,
|
||||
SubjectType: pr.SubjectType,
|
||||
Permission: policies.AdminPermission,
|
||||
Object: policies.MagistralaObject,
|
||||
ObjectType: policies.PlatformType,
|
||||
}); err == nil {
|
||||
isSuperAdmin = true
|
||||
}
|
||||
|
||||
if !isSuperAdmin {
|
||||
preconds = append(preconds, &v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.DomainType,
|
||||
OptionalResourceId: pr.Domain,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.UserType,
|
||||
OptionalSubjectId: pr.Subject,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
switch {
|
||||
// For New thing
|
||||
// - THING without DOMAIN RELATION to ANY DOMAIN
|
||||
case pr.ObjectKind == policies.NewThingKind:
|
||||
preconds = append(preconds,
|
||||
&v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.ThingType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
default:
|
||||
// For existing thing
|
||||
// - THING without DOMAIN RELATION to ANY DOMAIN
|
||||
preconds = append(preconds,
|
||||
&v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.ThingType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
OptionalSubjectId: pr.Domain,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
return preconds, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) userDomainPreConditions(ctx context.Context, pr policies.Policy) ([]*v1.Precondition, error) {
|
||||
var preconds []*v1.Precondition
|
||||
|
||||
if err := ps.checkPolicy(ctx, policies.Policy{
|
||||
Subject: pr.Subject,
|
||||
SubjectType: pr.SubjectType,
|
||||
Permission: policies.AdminPermission,
|
||||
Object: policies.MagistralaObject,
|
||||
ObjectType: policies.PlatformType,
|
||||
}); err == nil {
|
||||
return preconds, fmt.Errorf("use already exists in domain")
|
||||
}
|
||||
|
||||
// user should not have any relation with domain.
|
||||
preconds = append(preconds, &v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.DomainType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.UserType,
|
||||
OptionalSubjectId: pr.Subject,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
return preconds, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) checkPolicy(ctx context.Context, pr policies.Policy) error {
|
||||
checkReq := v1.CheckPermissionRequest{
|
||||
// FullyConsistent means little caching will be available, which means performance will suffer.
|
||||
// Only use if a ZedToken is not available or absolutely latest information is required.
|
||||
// If we want to avoid FullyConsistent and to improve the performance of spicedb, then we need to cache the ZEDTOKEN whenever RELATIONS is created or updated.
|
||||
// Instead of using FullyConsistent we need to use Consistency_AtLeastAsFresh, code looks like below one.
|
||||
// Consistency: &v1.Consistency{
|
||||
// Requirement: &v1.Consistency_AtLeastAsFresh{
|
||||
// AtLeastAsFresh: getRelationTupleZedTokenFromCache() ,
|
||||
// }
|
||||
// },
|
||||
// Reference: https://authzed.com/docs/reference/api-consistency
|
||||
Consistency: &v1.Consistency{
|
||||
Requirement: &v1.Consistency_FullyConsistent{
|
||||
FullyConsistent: true,
|
||||
},
|
||||
},
|
||||
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
|
||||
Permission: pr.Permission,
|
||||
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
|
||||
}
|
||||
|
||||
resp, err := ps.permissionClient.CheckPermission(ctx, &checkReq)
|
||||
if err != nil {
|
||||
return handleSpicedbError(err)
|
||||
}
|
||||
if resp.Permissionship == v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION {
|
||||
return nil
|
||||
}
|
||||
if reason, ok := v1.CheckPermissionResponse_Permissionship_name[int32(resp.Permissionship)]; ok {
|
||||
return errors.Wrap(svcerr.ErrAuthorization, errors.New(reason))
|
||||
}
|
||||
return svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
func (ps *policyService) retrieveObjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) ([]policies.Policy, string, error) {
|
||||
resourceReq := &v1.LookupResourcesRequest{
|
||||
Consistency: &v1.Consistency{
|
||||
Requirement: &v1.Consistency_FullyConsistent{
|
||||
FullyConsistent: true,
|
||||
},
|
||||
},
|
||||
ResourceObjectType: pr.ObjectType,
|
||||
Permission: pr.Permission,
|
||||
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
|
||||
OptionalLimit: uint32(limit),
|
||||
}
|
||||
if nextPageToken != "" {
|
||||
resourceReq.OptionalCursor = &v1.Cursor{Token: nextPageToken}
|
||||
}
|
||||
stream, err := ps.permissionClient.LookupResources(ctx, resourceReq)
|
||||
if err != nil {
|
||||
return nil, "", errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
|
||||
}
|
||||
resources := []*v1.LookupResourcesResponse{}
|
||||
var token string
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
switch err {
|
||||
case nil:
|
||||
resources = append(resources, resp)
|
||||
case io.EOF:
|
||||
if len(resources) > 0 && resources[len(resources)-1].AfterResultCursor != nil {
|
||||
token = resources[len(resources)-1].AfterResultCursor.Token
|
||||
}
|
||||
return objectsToAuthPolicies(resources), token, nil
|
||||
default:
|
||||
if len(resources) > 0 && resources[len(resources)-1].AfterResultCursor != nil {
|
||||
token = resources[len(resources)-1].AfterResultCursor.Token
|
||||
}
|
||||
return []policies.Policy{}, token, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *policyService) retrieveAllObjects(ctx context.Context, pr policies.Policy) ([]policies.Policy, error) {
|
||||
resourceReq := &v1.LookupResourcesRequest{
|
||||
Consistency: &v1.Consistency{
|
||||
Requirement: &v1.Consistency_FullyConsistent{
|
||||
FullyConsistent: true,
|
||||
},
|
||||
},
|
||||
ResourceObjectType: pr.ObjectType,
|
||||
Permission: pr.Permission,
|
||||
Subject: &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: pr.SubjectType, ObjectId: pr.Subject}, OptionalRelation: pr.SubjectRelation},
|
||||
}
|
||||
stream, err := ps.permissionClient.LookupResources(ctx, resourceReq)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
|
||||
}
|
||||
tuples := []policies.Policy{}
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
switch {
|
||||
case errors.Contains(err, io.EOF):
|
||||
return tuples, nil
|
||||
case err != nil:
|
||||
return tuples, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
|
||||
default:
|
||||
tuples = append(tuples, policies.Policy{Object: resp.ResourceObjectId})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *policyService) retrieveSubjects(ctx context.Context, pr policies.Policy, nextPageToken string, limit uint64) ([]policies.Policy, string, error) {
|
||||
subjectsReq := v1.LookupSubjectsRequest{
|
||||
Consistency: &v1.Consistency{
|
||||
Requirement: &v1.Consistency_FullyConsistent{
|
||||
FullyConsistent: true,
|
||||
},
|
||||
},
|
||||
Resource: &v1.ObjectReference{ObjectType: pr.ObjectType, ObjectId: pr.Object},
|
||||
Permission: pr.Permission,
|
||||
SubjectObjectType: pr.SubjectType,
|
||||
OptionalSubjectRelation: pr.SubjectRelation,
|
||||
OptionalConcreteLimit: uint32(limit),
|
||||
WildcardOption: v1.LookupSubjectsRequest_WILDCARD_OPTION_INCLUDE_WILDCARDS,
|
||||
}
|
||||
if nextPageToken != "" {
|
||||
subjectsReq.OptionalCursor = &v1.Cursor{Token: nextPageToken}
|
||||
}
|
||||
stream, err := ps.permissionClient.LookupSubjects(ctx, &subjectsReq)
|
||||
if err != nil {
|
||||
return nil, "", errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
|
||||
}
|
||||
subjects := []*v1.LookupSubjectsResponse{}
|
||||
var token string
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
|
||||
switch err {
|
||||
case nil:
|
||||
subjects = append(subjects, resp)
|
||||
case io.EOF:
|
||||
if len(subjects) > 0 && subjects[len(subjects)-1].AfterResultCursor != nil {
|
||||
token = subjects[len(subjects)-1].AfterResultCursor.Token
|
||||
}
|
||||
return subjectsToAuthPolicies(subjects), token, nil
|
||||
default:
|
||||
if len(subjects) > 0 && subjects[len(subjects)-1].AfterResultCursor != nil {
|
||||
token = subjects[len(subjects)-1].AfterResultCursor.Token
|
||||
}
|
||||
return []policies.Policy{}, token, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *policyService) retrieveAllSubjects(ctx context.Context, pr policies.Policy) ([]policies.Policy, error) {
|
||||
var tuples []policies.Policy
|
||||
nextPageToken := ""
|
||||
for i := 0; ; i++ {
|
||||
relationTuples, npt, err := ps.retrieveSubjects(ctx, pr, nextPageToken, defRetrieveAllLimit)
|
||||
if err != nil {
|
||||
return tuples, err
|
||||
}
|
||||
tuples = append(tuples, relationTuples...)
|
||||
if npt == "" || (len(tuples) < defRetrieveAllLimit) {
|
||||
break
|
||||
}
|
||||
nextPageToken = npt
|
||||
}
|
||||
return tuples, nil
|
||||
}
|
||||
|
||||
func (ps *policyService) retrievePermissions(ctx context.Context, pr policies.Policy, filterPermission []string) (policies.Permissions, error) {
|
||||
var permissionChecks []*v1.CheckBulkPermissionsRequestItem
|
||||
for _, fp := range filterPermission {
|
||||
permissionChecks = append(permissionChecks, &v1.CheckBulkPermissionsRequestItem{
|
||||
Resource: &v1.ObjectReference{
|
||||
ObjectType: pr.ObjectType,
|
||||
ObjectId: pr.Object,
|
||||
},
|
||||
Permission: fp,
|
||||
Subject: &v1.SubjectReference{
|
||||
Object: &v1.ObjectReference{
|
||||
ObjectType: pr.SubjectType,
|
||||
ObjectId: pr.Subject,
|
||||
},
|
||||
OptionalRelation: pr.SubjectRelation,
|
||||
},
|
||||
})
|
||||
}
|
||||
resp, err := ps.client.PermissionsServiceClient.CheckBulkPermissions(ctx, &v1.CheckBulkPermissionsRequest{
|
||||
Consistency: &v1.Consistency{
|
||||
Requirement: &v1.Consistency_FullyConsistent{
|
||||
FullyConsistent: true,
|
||||
},
|
||||
},
|
||||
Items: permissionChecks,
|
||||
})
|
||||
if err != nil {
|
||||
return policies.Permissions{}, errors.Wrap(errRetrievePolicies, handleSpicedbError(err))
|
||||
}
|
||||
|
||||
permissions := []string{}
|
||||
for _, pair := range resp.Pairs {
|
||||
if pair.GetError() != nil {
|
||||
s := pair.GetError()
|
||||
return policies.Permissions{}, errors.Wrap(errRetrievePolicies, convertGRPCStatusToError(convertToGrpcStatus(s)))
|
||||
}
|
||||
item := pair.GetItem()
|
||||
req := pair.GetRequest()
|
||||
if item != nil && req != nil && item.Permissionship == v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION {
|
||||
permissions = append(permissions, req.GetPermission())
|
||||
}
|
||||
}
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
func groupPreConditions(pr policies.Policy) ([]*v1.Precondition, error) {
|
||||
// - PARENT_GROUP (subject) with DOMAIN RELATION to DOMAIN
|
||||
precond := []*v1.Precondition{
|
||||
{
|
||||
Operation: v1.Precondition_OPERATION_MUST_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.GroupType,
|
||||
OptionalResourceId: pr.Subject,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
OptionalSubjectId: pr.Domain,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if pr.ObjectKind != policies.ChannelsKind {
|
||||
precond = append(precond,
|
||||
&v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.GroupType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalRelation: policies.ParentGroupRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.GroupType,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
switch {
|
||||
// - NEW CHILD_GROUP (object) with out DOMAIN RELATION to ANY DOMAIN
|
||||
case pr.ObjectType == policies.GroupType && pr.ObjectKind == policies.NewGroupKind:
|
||||
precond = append(precond,
|
||||
&v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.GroupType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
default:
|
||||
// - CHILD_GROUP (object) with DOMAIN RELATION to DOMAIN
|
||||
precond = append(precond,
|
||||
&v1.Precondition{
|
||||
Operation: v1.Precondition_OPERATION_MUST_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.GroupType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
OptionalSubjectId: pr.Domain,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
return precond, nil
|
||||
}
|
||||
|
||||
func channelThingPreCondition(pr policies.Policy) ([]*v1.Precondition, error) {
|
||||
if pr.SubjectKind != policies.ChannelsKind {
|
||||
return nil, errors.Wrap(errors.ErrMalformedEntity, errInvalidSubject)
|
||||
}
|
||||
precond := []*v1.Precondition{
|
||||
{
|
||||
Operation: v1.Precondition_OPERATION_MUST_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.GroupType,
|
||||
OptionalResourceId: pr.Subject,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
OptionalSubjectId: pr.Domain,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Operation: v1.Precondition_OPERATION_MUST_NOT_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.GroupType,
|
||||
OptionalRelation: policies.ParentGroupRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.GroupType,
|
||||
OptionalSubjectId: pr.Subject,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Operation: v1.Precondition_OPERATION_MUST_MATCH,
|
||||
Filter: &v1.RelationshipFilter{
|
||||
ResourceType: policies.ThingType,
|
||||
OptionalResourceId: pr.Object,
|
||||
OptionalRelation: policies.DomainRelation,
|
||||
OptionalSubjectFilter: &v1.SubjectFilter{
|
||||
SubjectType: policies.DomainType,
|
||||
OptionalSubjectId: pr.Domain,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
return precond, nil
|
||||
}
|
||||
|
||||
func objectsToAuthPolicies(objects []*v1.LookupResourcesResponse) []policies.Policy {
|
||||
var policyList []policies.Policy
|
||||
for _, obj := range objects {
|
||||
policyList = append(policyList, policies.Policy{
|
||||
Object: obj.GetResourceObjectId(),
|
||||
})
|
||||
}
|
||||
return policyList
|
||||
}
|
||||
|
||||
func subjectsToAuthPolicies(subjects []*v1.LookupSubjectsResponse) []policies.Policy {
|
||||
var policyList []policies.Policy
|
||||
for _, sub := range subjects {
|
||||
policyList = append(policyList, policies.Policy{
|
||||
Subject: sub.Subject.GetSubjectObjectId(),
|
||||
})
|
||||
}
|
||||
return policyList
|
||||
}
|
||||
|
||||
func handleSpicedbError(err error) error {
|
||||
if st, ok := status.FromError(err); ok {
|
||||
return convertGRPCStatusToError(st)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func convertToGrpcStatus(gst *gstatus.Status) *status.Status {
|
||||
st := status.New(codes.Code(gst.Code), gst.GetMessage())
|
||||
return st
|
||||
}
|
||||
|
||||
func convertGRPCStatusToError(st *status.Status) error {
|
||||
switch st.Code() {
|
||||
case codes.NotFound:
|
||||
return errors.Wrap(repoerr.ErrNotFound, errors.New(st.Message()))
|
||||
case codes.InvalidArgument:
|
||||
return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message()))
|
||||
case codes.AlreadyExists:
|
||||
return errors.Wrap(repoerr.ErrConflict, errors.New(st.Message()))
|
||||
case codes.Unauthenticated:
|
||||
return errors.Wrap(svcerr.ErrAuthentication, errors.New(st.Message()))
|
||||
case codes.Internal:
|
||||
return errors.Wrap(errInternal, errors.New(st.Message()))
|
||||
case codes.OK:
|
||||
if msg := st.Message(); msg != "" {
|
||||
return errors.Wrap(errors.ErrUnidentified, errors.New(msg))
|
||||
}
|
||||
return nil
|
||||
case codes.FailedPrecondition:
|
||||
return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message()))
|
||||
case codes.PermissionDenied:
|
||||
return errors.Wrap(svcerr.ErrAuthorization, errors.New(st.Message()))
|
||||
default:
|
||||
return errors.Wrap(fmt.Errorf("unexpected gRPC status: %s (status code:%v)", st.Code().String(), st.Code()), errors.New(st.Message()))
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user