mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 06:50:18 +00:00
MG-37 - Add Rules Engine tests (#74)
* add service and endpoint tests Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * update github workflows Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix failing linter Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * remove unused field Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * remove logs Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * update github workflows Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix time format Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * update to plural Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> --------- Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
@@ -110,6 +110,11 @@ jobs:
|
||||
- "things/**"
|
||||
- "auth/**"
|
||||
|
||||
re:
|
||||
- "re/**"
|
||||
- "cmd/re/**"
|
||||
- "re/api/**"
|
||||
|
||||
- name: Create coverage directory
|
||||
run: |
|
||||
mkdir coverage
|
||||
@@ -154,6 +159,12 @@ jobs:
|
||||
run: |
|
||||
go test --race -v -count=1 -coverprofile=coverage/readers.out ./readers/...
|
||||
|
||||
- name: Run rule engine tests
|
||||
if: steps.changes.outputs.re == 'true' || steps.changes.outputs.workflow == 'true'
|
||||
run: |
|
||||
go test --race -v -count=1 -coverprofile=coverage/re.out ./re/...
|
||||
|
||||
|
||||
- name: Upload coverage
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
|
||||
+4
-1
@@ -33,6 +33,7 @@ import (
|
||||
httpserver "github.com/absmach/supermq/pkg/server/http"
|
||||
"github.com/absmach/supermq/pkg/uuid"
|
||||
"github.com/caarlos0/env/v11"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -175,7 +176,9 @@ func main() {
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, logger, cfg.InstanceID), logger)
|
||||
mux := chi.NewRouter()
|
||||
|
||||
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID), logger)
|
||||
|
||||
if cfg.SendTelemetry {
|
||||
chc := chclient.New(svcName, supermq.Version, logger, cancel)
|
||||
|
||||
@@ -0,0 +1,997 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/0x6flab/namegenerator"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/absmach/magistrala/re"
|
||||
"github.com/absmach/magistrala/re/api"
|
||||
"github.com/absmach/magistrala/re/mocks"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/auth"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
const contentType = "application/json"
|
||||
|
||||
var (
|
||||
namegen = namegenerator.NewGenerator()
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
userID = testsutil.GenerateUUID(&testing.T{})
|
||||
validID = testsutil.GenerateUUID(&testing.T{})
|
||||
validToken = "valid"
|
||||
invalidToken = "invalid"
|
||||
now = time.Now().UTC().Truncate(time.Minute)
|
||||
schedule = re.Schedule{
|
||||
StartDateTime: now.Add(-1 * time.Hour),
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
}
|
||||
rule = re.Rule{
|
||||
ID: validID,
|
||||
Name: namegen.Generate(),
|
||||
DomainID: domainID,
|
||||
Schedule: schedule,
|
||||
Metadata: re.Metadata{
|
||||
"name": "test",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type testRequest struct {
|
||||
client *http.Client
|
||||
method string
|
||||
url string
|
||||
contentType string
|
||||
token string
|
||||
body io.Reader
|
||||
}
|
||||
|
||||
func (tr testRequest) make() (*http.Response, error) {
|
||||
req, err := http.NewRequest(tr.method, tr.url, tr.body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tr.token != "" {
|
||||
req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token)
|
||||
}
|
||||
|
||||
if tr.contentType != "" {
|
||||
req.Header.Set("Content-Type", tr.contentType)
|
||||
}
|
||||
|
||||
req.Header.Set("Referer", "http://localhost")
|
||||
|
||||
return tr.client.Do(req)
|
||||
}
|
||||
|
||||
func newRuleEngineServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) {
|
||||
svc := new(mocks.Service)
|
||||
authn := new(authnmocks.Authentication)
|
||||
|
||||
logger := smqlog.NewMock()
|
||||
mux := chi.NewRouter()
|
||||
api.MakeHandler(svc, authn, mux, logger, "")
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
|
||||
func toJSON(data any) string {
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(jsonData)
|
||||
}
|
||||
|
||||
func TestAddRuleEndpoint(t *testing.T) {
|
||||
ts, svc, authn := newRuleEngineServer()
|
||||
defer ts.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
rule re.Rule
|
||||
domainID string
|
||||
token string
|
||||
contentType string
|
||||
status int
|
||||
authnRes smqauthn.Session
|
||||
authnErr error
|
||||
svcRes re.Rule
|
||||
svcErr error
|
||||
err error
|
||||
len int
|
||||
}{
|
||||
{
|
||||
desc: "add rule successfully",
|
||||
rule: rule,
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
domainID: domainID,
|
||||
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
|
||||
status: http.StatusCreated,
|
||||
svcRes: rule,
|
||||
},
|
||||
{
|
||||
desc: "add rule with invalid token",
|
||||
rule: rule,
|
||||
token: invalidToken,
|
||||
authnRes: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
contentType: contentType,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "add rule with empty token",
|
||||
token: "",
|
||||
authnRes: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
rule: rule,
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "add rule with name that is too long",
|
||||
token: validToken,
|
||||
rule: re.Rule{
|
||||
ID: validID,
|
||||
Name: strings.Repeat("a", 1025),
|
||||
Logic: re.Script{
|
||||
Type: re.ScriptType(0),
|
||||
Value: "return `test` end",
|
||||
},
|
||||
},
|
||||
domainID: domainID,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrNameSize,
|
||||
},
|
||||
{
|
||||
desc: "add rule with empty domainID",
|
||||
token: validToken,
|
||||
rule: rule,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMissingDomainID,
|
||||
},
|
||||
{
|
||||
desc: "add rule with invalid content type",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
rule: rule,
|
||||
contentType: "application/xml",
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "add rule with service error",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
|
||||
rule: rule,
|
||||
contentType: contentType,
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusForbidden,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
data := toJSON(tc.rule)
|
||||
req := testRequest{
|
||||
client: ts.Client(),
|
||||
method: http.MethodPost,
|
||||
url: fmt.Sprintf("%s/%s/rules", ts.URL, tc.domainID),
|
||||
contentType: tc.contentType,
|
||||
token: tc.token,
|
||||
body: strings.NewReader(data),
|
||||
}
|
||||
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("AddRule", mock.Anything, tc.authnRes, tc.rule).Return(tc.svcRes, tc.svcErr)
|
||||
res, err := req.make()
|
||||
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var errRes respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&errRes)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if errRes.Err != "" || errRes.Message != "" {
|
||||
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewRuleEndpoint(t *testing.T) {
|
||||
ts, svc, authn := newRuleEngineServer()
|
||||
defer ts.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
id string
|
||||
domainID string
|
||||
token string
|
||||
contentType string
|
||||
status int
|
||||
authnRes smqauthn.Session
|
||||
authnErr error
|
||||
svcRes re.Rule
|
||||
svcErr error
|
||||
err error
|
||||
len int
|
||||
}{
|
||||
{
|
||||
desc: "view rule successfully",
|
||||
id: rule.ID,
|
||||
token: validToken,
|
||||
contentType: contentType,
|
||||
domainID: domainID,
|
||||
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
|
||||
status: http.StatusOK,
|
||||
svcRes: rule,
|
||||
},
|
||||
{
|
||||
desc: "view rule with invalid token",
|
||||
id: rule.ID,
|
||||
token: invalidToken,
|
||||
authnRes: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
contentType: contentType,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "view rule with empty token",
|
||||
token: "",
|
||||
authnRes: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
id: rule.ID,
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "view rule with empty domainID",
|
||||
token: validToken,
|
||||
id: rule.ID,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMissingDomainID,
|
||||
},
|
||||
{
|
||||
desc: "view rule with service error",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
|
||||
id: rule.ID,
|
||||
contentType: contentType,
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusForbidden,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
client: ts.Client(),
|
||||
method: http.MethodGet,
|
||||
url: fmt.Sprintf("%s/%s/rules/%s", ts.URL, tc.domainID, tc.id),
|
||||
token: tc.token,
|
||||
}
|
||||
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("ViewRule", mock.Anything, tc.authnRes, tc.id).Return(tc.svcRes, tc.svcErr)
|
||||
res, err := req.make()
|
||||
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var errRes respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&errRes)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if errRes.Err != "" || errRes.Message != "" {
|
||||
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListRulesEndpoint(t *testing.T) {
|
||||
ts, svc, authn := newRuleEngineServer()
|
||||
defer ts.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
query string
|
||||
domainID string
|
||||
token string
|
||||
session smqauthn.Session
|
||||
listRulesResponse re.Page
|
||||
status int
|
||||
authnErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list rules successfully",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
status: http.StatusOK,
|
||||
listRulesResponse: re.Page{
|
||||
PageMeta: re.PageMeta{
|
||||
Total: 1,
|
||||
},
|
||||
Rules: []re.Rule{rule},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules with empty token",
|
||||
domainID: domainID,
|
||||
token: "",
|
||||
status: http.StatusUnauthorized,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "list rules with invalid token",
|
||||
domainID: domainID,
|
||||
token: invalidToken,
|
||||
status: http.StatusUnauthorized,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "list rules with offset",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
listRulesResponse: re.Page{
|
||||
PageMeta: re.PageMeta{
|
||||
Total: 1,
|
||||
},
|
||||
Rules: []re.Rule{rule},
|
||||
},
|
||||
query: "offset=1",
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules with invalid offset",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
query: "offset=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
},
|
||||
{
|
||||
desc: "list rules with limit",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
listRulesResponse: re.Page{
|
||||
PageMeta: re.PageMeta{
|
||||
Total: 1,
|
||||
},
|
||||
Rules: []re.Rule{rule},
|
||||
},
|
||||
query: "limit=1",
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules with invalid limit",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
query: "limit=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
},
|
||||
{
|
||||
desc: "list rules with invalid direction",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
query: "dir=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrInvalidDirection,
|
||||
},
|
||||
{
|
||||
desc: "list rule with limit that is too big",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
query: "limit=10000",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrLimitSize,
|
||||
},
|
||||
{
|
||||
desc: "list rules with input channel",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
listRulesResponse: re.Page{
|
||||
PageMeta: re.PageMeta{
|
||||
Total: 1,
|
||||
},
|
||||
Rules: []re.Rule{rule},
|
||||
},
|
||||
query: "input_channel=input.channel",
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules with duplicate input_channel",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
query: "input_channel=1&input_channel=2",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list rules with status",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
listRulesResponse: re.Page{
|
||||
PageMeta: re.PageMeta{
|
||||
Total: 1,
|
||||
},
|
||||
Rules: []re.Rule{rule},
|
||||
},
|
||||
query: "status=enabled",
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules with invalid status",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
query: "status=invalid",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
},
|
||||
{
|
||||
desc: "list rules with duplicate status",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
query: "status=enabled&status=disabled",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
{
|
||||
desc: "list rules with output channel",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
listRulesResponse: re.Page{
|
||||
PageMeta: re.PageMeta{
|
||||
Total: 1,
|
||||
},
|
||||
Rules: []re.Rule{rule},
|
||||
},
|
||||
query: "output_channel=output.channel",
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules with duplicate output channel",
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
query: "output_channel=1&output_channel=2",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrInvalidQueryParams,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
client: ts.Client(),
|
||||
method: http.MethodGet,
|
||||
url: ts.URL + "/" + tc.domainID + "/rules?" + tc.query,
|
||||
contentType: contentType,
|
||||
token: tc.token,
|
||||
}
|
||||
if tc.token == validToken {
|
||||
tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}
|
||||
}
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
|
||||
svcCall := svc.On("ListRules", mock.Anything, tc.session, mock.Anything).Return(tc.listRulesResponse, tc.err)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var bodyRes respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&bodyRes)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if bodyRes.Err != "" || bodyRes.Message != "" {
|
||||
err = errors.Wrap(errors.New(bodyRes.Err), errors.New(bodyRes.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRulesEndpoint(t *testing.T) {
|
||||
ts, svc, authn := newRuleEngineServer()
|
||||
defer ts.Close()
|
||||
|
||||
updateRuleReq := re.Rule{
|
||||
ID: rule.ID,
|
||||
Name: rule.Name,
|
||||
Logic: re.Script{
|
||||
Type: re.ScriptType(0),
|
||||
Value: "return `test` end",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"name": "test",
|
||||
},
|
||||
}
|
||||
|
||||
invalidReq := re.Rule{
|
||||
ID: rule.ID,
|
||||
Name: rule.Name,
|
||||
Metadata: map[string]any{
|
||||
"name": "test",
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
id string
|
||||
domainID string
|
||||
updateReq re.Rule
|
||||
contentType string
|
||||
session smqauthn.Session
|
||||
svcResp re.Rule
|
||||
svcErr error
|
||||
status int
|
||||
authnErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "update rule successfully",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
id: rule.ID,
|
||||
updateReq: updateRuleReq,
|
||||
contentType: contentType,
|
||||
svcResp: rule,
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update rule with invalid token",
|
||||
token: invalidToken,
|
||||
session: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
id: rule.ID,
|
||||
updateReq: updateRuleReq,
|
||||
contentType: contentType,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "update rule with empty token",
|
||||
token: "",
|
||||
session: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
id: rule.ID,
|
||||
updateReq: updateRuleReq,
|
||||
contentType: contentType,
|
||||
status: http.StatusUnauthorized,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "update rule with empty domainID",
|
||||
token: validToken,
|
||||
id: rule.ID,
|
||||
updateReq: updateRuleReq,
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMissingDomainID,
|
||||
},
|
||||
{
|
||||
desc: "update rule with empty logic",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
id: rule.ID,
|
||||
updateReq: invalidReq,
|
||||
contentType: contentType,
|
||||
svcResp: rule,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrEmptyList,
|
||||
},
|
||||
{
|
||||
desc: "update rule with name that is too long",
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
updateReq: re.Rule{
|
||||
ID: validID,
|
||||
Name: strings.Repeat("a", 1025),
|
||||
Logic: re.Script{
|
||||
Type: re.ScriptType(0),
|
||||
Value: "return `test` end",
|
||||
},
|
||||
},
|
||||
contentType: contentType,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrNameSize,
|
||||
},
|
||||
{
|
||||
desc: "update rule with invalid content type",
|
||||
token: validToken,
|
||||
id: rule.ID,
|
||||
domainID: domainID,
|
||||
updateReq: updateRuleReq,
|
||||
contentType: "application/xml",
|
||||
svcResp: rule,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "update rule with service error",
|
||||
token: validToken,
|
||||
id: rule.ID,
|
||||
domainID: domainID,
|
||||
updateReq: updateRuleReq,
|
||||
contentType: contentType,
|
||||
svcResp: re.Rule{},
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusForbidden,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
data := toJSON(tc.updateReq)
|
||||
req := testRequest{
|
||||
client: ts.Client(),
|
||||
method: http.MethodPatch,
|
||||
url: fmt.Sprintf("%s/%s/rules/%s", ts.URL, tc.domainID, tc.id),
|
||||
contentType: tc.contentType,
|
||||
token: tc.token,
|
||||
body: strings.NewReader(data),
|
||||
}
|
||||
if tc.token == validToken {
|
||||
tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}
|
||||
}
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
|
||||
svcCall := svc.On("UpdateRule", mock.Anything, tc.session, tc.updateReq).Return(tc.svcResp, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var errRes respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&errRes)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if errRes.Err != "" || errRes.Message != "" {
|
||||
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableRuleEndpoint(t *testing.T) {
|
||||
ts, svc, authn := newRuleEngineServer()
|
||||
defer ts.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
id string
|
||||
domainID string
|
||||
session smqauthn.Session
|
||||
svcResp re.Rule
|
||||
svcErr error
|
||||
status int
|
||||
authnErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "enable rule successfully",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
id: validID,
|
||||
svcResp: rule,
|
||||
svcErr: nil,
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "enable rule with invalid token",
|
||||
token: invalidToken,
|
||||
session: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
id: validID,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "enable rule with empty token",
|
||||
token: "",
|
||||
session: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
id: validID,
|
||||
status: http.StatusUnauthorized,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "enable rule with empty domainID",
|
||||
token: validToken,
|
||||
id: validID,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMissingDomainID,
|
||||
},
|
||||
{
|
||||
desc: "enable rule with service error",
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
svcResp: re.Rule{},
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusForbidden,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "enable rule with empty id",
|
||||
token: validToken,
|
||||
id: "",
|
||||
domainID: domainID,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
client: ts.Client(),
|
||||
method: http.MethodPost,
|
||||
url: fmt.Sprintf("%s/%s/rules/%s/enable", ts.URL, tc.domainID, tc.id),
|
||||
token: tc.token,
|
||||
}
|
||||
if tc.token == validToken {
|
||||
tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}
|
||||
}
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
|
||||
svcCall := svc.On("EnableRule", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var errRes respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&errRes)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if errRes.Err != "" || errRes.Message != "" {
|
||||
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisableRuleEndpoint(t *testing.T) {
|
||||
gs, svc, authn := newRuleEngineServer()
|
||||
defer gs.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
id string
|
||||
domainID string
|
||||
session smqauthn.Session
|
||||
svcResp re.Rule
|
||||
svcErr error
|
||||
status int
|
||||
authnErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "disable rule successfully",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
id: validID,
|
||||
svcResp: rule,
|
||||
svcErr: nil,
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "disable rule with invalid token",
|
||||
token: invalidToken,
|
||||
session: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
id: validID,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "disable rule with empty token",
|
||||
token: "",
|
||||
session: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
id: validID,
|
||||
status: http.StatusUnauthorized,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "disable rule with empty domainID",
|
||||
token: validToken,
|
||||
id: validID,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMissingDomainID,
|
||||
},
|
||||
{
|
||||
desc: "disable rule with service error",
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
svcResp: re.Rule{},
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusForbidden,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "disable rule with empty id",
|
||||
token: validToken,
|
||||
id: "",
|
||||
domainID: domainID,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
client: gs.Client(),
|
||||
method: http.MethodPost,
|
||||
url: fmt.Sprintf("%s/%s/rules/%s/disable", gs.URL, tc.domainID, tc.id),
|
||||
token: tc.token,
|
||||
}
|
||||
if tc.token == validToken {
|
||||
tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}
|
||||
}
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
|
||||
svcCall := svc.On("DisableRule", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
var errRes respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&errRes)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if errRes.Err != "" || errRes.Message != "" {
|
||||
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRuleEndpoint(t *testing.T) {
|
||||
ts, svc, authn := newRuleEngineServer()
|
||||
defer ts.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
id string
|
||||
domainID string
|
||||
session smqauthn.Session
|
||||
svcErr error
|
||||
status int
|
||||
authnErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "delete rule successfully",
|
||||
token: validToken,
|
||||
domainID: domainID,
|
||||
id: validID,
|
||||
svcErr: nil,
|
||||
status: http.StatusNoContent,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "delete rule with invalid token",
|
||||
token: invalidToken,
|
||||
session: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
id: validID,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
status: http.StatusUnauthorized,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "delete rule with empty token",
|
||||
token: "",
|
||||
session: smqauthn.Session{},
|
||||
domainID: domainID,
|
||||
id: validID,
|
||||
status: http.StatusUnauthorized,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "delete rule with empty domainID",
|
||||
token: validToken,
|
||||
id: validID,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMissingDomainID,
|
||||
},
|
||||
{
|
||||
desc: "delete rule with service error",
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusForbidden,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
client: ts.Client(),
|
||||
method: http.MethodDelete,
|
||||
url: fmt.Sprintf("%s/%s/rules/%s", ts.URL, tc.domainID, tc.id),
|
||||
token: tc.token,
|
||||
}
|
||||
if tc.token == validToken {
|
||||
tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}
|
||||
}
|
||||
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
|
||||
svcCall := svc.On("RemoveRule", mock.Anything, tc.session, tc.id).Return(tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type respBody struct {
|
||||
Err string `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Total uint64 `json:"total"`
|
||||
ID string `json:"id"`
|
||||
Status re.Status `json:"status"`
|
||||
}
|
||||
+9
-2
@@ -30,11 +30,10 @@ const (
|
||||
)
|
||||
|
||||
// MakeHandler creates an HTTP handler for the service endpoints.
|
||||
func MakeHandler(svc re.Service, authn mgauthn.Authentication, logger *slog.Logger, instanceID string) http.Handler {
|
||||
func MakeHandler(svc re.Service, authn mgauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
mux := chi.NewRouter()
|
||||
mux.Group(func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, true))
|
||||
r.Route("/{domainID}/rules", func(r chi.Router) {
|
||||
@@ -112,6 +111,9 @@ func decodeViewRuleRequest(_ context.Context, r *http.Request) (interface{}, err
|
||||
}
|
||||
|
||||
func decodeUpdateRuleRequest(_ context.Context, r *http.Request) (interface{}, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
var rule re.Rule
|
||||
if err := json.NewDecoder(r.Body).Decode(&rule); err != nil {
|
||||
return nil, err
|
||||
@@ -148,6 +150,10 @@ func decodeListRulesRequest(_ context.Context, r *http.Request) (interface{}, er
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
st, err := re.ToStatus(s)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
@@ -159,6 +165,7 @@ func decodeListRulesRequest(_ context.Context, r *http.Request) (interface{}, er
|
||||
InputChannel: ic,
|
||||
OutputChannel: oc,
|
||||
Status: st,
|
||||
Dir: dir,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
+37
-10
@@ -5,12 +5,13 @@ package re
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
"github.com/absmach/supermq/consumers"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
mgjson "github.com/absmach/supermq/pkg/transformers/json"
|
||||
lua "github.com/yuin/gopher-lua"
|
||||
@@ -131,29 +132,47 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule,
|
||||
|
||||
rule, err := re.repo.AddRule(ctx, r)
|
||||
if err != nil {
|
||||
return Rule{}, err
|
||||
return Rule{}, errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (re *re) ViewRule(ctx context.Context, session authn.Session, id string) (Rule, error) {
|
||||
return re.repo.ViewRule(ctx, id)
|
||||
rule, err := re.repo.ViewRule(ctx, id)
|
||||
if err != nil {
|
||||
return Rule{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (re *re) UpdateRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) {
|
||||
r.UpdatedAt = time.Now()
|
||||
r.UpdatedBy = session.UserID
|
||||
return re.repo.UpdateRule(ctx, r)
|
||||
rule, err := re.repo.UpdateRule(ctx, r)
|
||||
if err != nil {
|
||||
return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (re *re) ListRules(ctx context.Context, session authn.Session, pm PageMeta) (Page, error) {
|
||||
pm.Domain = session.DomainID
|
||||
return re.repo.ListRules(ctx, pm)
|
||||
page, err := re.repo.ListRules(ctx, pm)
|
||||
if err != nil {
|
||||
return Page{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (re *re) RemoveRule(ctx context.Context, session authn.Session, id string) error {
|
||||
return re.repo.RemoveRule(ctx, id)
|
||||
if err := re.repo.RemoveRule(ctx, id); err != nil {
|
||||
return errors.Wrap(svcerr.ErrRemoveEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (re *re) EnableRule(ctx context.Context, session authn.Session, id string) (Rule, error) {
|
||||
@@ -161,7 +180,11 @@ func (re *re) EnableRule(ctx context.Context, session authn.Session, id string)
|
||||
if err != nil {
|
||||
return Rule{}, err
|
||||
}
|
||||
return re.repo.UpdateRuleStatus(ctx, id, status)
|
||||
rule, err := re.repo.UpdateRuleStatus(ctx, id, status)
|
||||
if err != nil {
|
||||
return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (re *re) DisableRule(ctx context.Context, session authn.Session, id string) (Rule, error) {
|
||||
@@ -169,7 +192,11 @@ func (re *re) DisableRule(ctx context.Context, session authn.Session, id string)
|
||||
if err != nil {
|
||||
return Rule{}, err
|
||||
}
|
||||
return re.repo.UpdateRuleStatus(ctx, id, status)
|
||||
rule, err := re.repo.UpdateRuleStatus(ctx, id, status)
|
||||
if err != nil {
|
||||
return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (re *re) ConsumeAsync(ctx context.Context, msgs interface{}) {
|
||||
@@ -181,7 +208,7 @@ func (re *re) ConsumeAsync(ctx context.Context, msgs interface{}) {
|
||||
}
|
||||
page, err := re.repo.ListRules(ctx, pm)
|
||||
if err != nil {
|
||||
re.errors <- err
|
||||
re.errors <- errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
return
|
||||
}
|
||||
for _, r := range page.Rules {
|
||||
@@ -282,7 +309,7 @@ func (r Rule) shouldRun(startTime time.Time) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
t := r.Schedule.Time.Truncate(time.Minute)
|
||||
t := r.Schedule.Time.Truncate(time.Minute).UTC()
|
||||
startTimeOnly := time.Date(0, 1, 1, startTime.Hour(), startTime.Minute(), 0, 0, time.UTC)
|
||||
if t.Equal(startTimeOnly) {
|
||||
return true
|
||||
|
||||
+715
-22
@@ -11,34 +11,44 @@ import (
|
||||
|
||||
"github.com/0x6flab/namegenerator"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/absmach/magistrala/re"
|
||||
"github.com/absmach/magistrala/re/mocks"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
"github.com/absmach/supermq/pkg/messaging"
|
||||
pubsubmocks "github.com/absmach/supermq/pkg/messaging/mocks"
|
||||
mgjson "github.com/absmach/supermq/pkg/transformers/json"
|
||||
"github.com/absmach/supermq/pkg/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
var (
|
||||
namegen = namegenerator.NewGenerator()
|
||||
rule = re.Rule{
|
||||
namegen = namegenerator.NewGenerator()
|
||||
userID = testsutil.GenerateUUID(&testing.T{})
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
ruleName = namegen.Generate()
|
||||
ruleID = testsutil.GenerateUUID(&testing.T{})
|
||||
inputChannel = "test.channel"
|
||||
schedule = re.Schedule{
|
||||
StartDateTime: time.Now().Add(-time.Hour),
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: time.Now().Add(-time.Hour),
|
||||
}
|
||||
rule = re.Rule{
|
||||
ID: testsutil.GenerateUUID(&testing.T{}),
|
||||
Name: namegen.Generate(),
|
||||
InputChannel: "test.channel",
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
Schedule: re.Schedule{
|
||||
StartDateTime: time.Now().Add(-time.Hour), // Started an hour ago
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: time.Now().Add(-time.Hour),
|
||||
},
|
||||
Schedule: schedule,
|
||||
}
|
||||
futureRule = re.Rule{
|
||||
ID: testsutil.GenerateUUID(&testing.T{}),
|
||||
Name: namegen.Generate(),
|
||||
InputChannel: "test.channel",
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
Schedule: re.Schedule{
|
||||
StartDateTime: time.Now().Add(24 * time.Hour),
|
||||
@@ -47,17 +57,648 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func newService(t *testing.T) (re.Service, *mocks.Repository, *mocks.Ticker) {
|
||||
func newService(t *testing.T) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *mocks.Ticker) {
|
||||
repo := new(mocks.Repository)
|
||||
mockTicker := new(mocks.Ticker)
|
||||
idProvider := uuid.NewMock()
|
||||
pubsub := pubsubmocks.NewPubSub(t)
|
||||
return re.NewService(repo, idProvider, pubsub, mockTicker), repo, mockTicker
|
||||
return re.NewService(repo, idProvider, pubsub, mockTicker), repo, pubsub, mockTicker
|
||||
}
|
||||
|
||||
func TestAddRule(t *testing.T) {
|
||||
svc, repo, _, _ := newService(t)
|
||||
ruleName := namegen.Generate()
|
||||
now := time.Now().Add(time.Hour)
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
rule re.Rule
|
||||
res re.Rule
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "Add rule successfully",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
rule: re.Rule{
|
||||
Name: ruleName,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: re.Schedule{
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
},
|
||||
res: re.Rule{
|
||||
Name: ruleName,
|
||||
ID: ruleID,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: re.Schedule{
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Add rule with failed repo",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
rule: re.Rule{
|
||||
Name: ruleName,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: re.Schedule{
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
},
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("AddRule", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
res, err := svc.AddRule(context.Background(), tc.session, tc.rule)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, res.ID, "expected non-empty result in ID")
|
||||
assert.Equal(t, tc.rule.Name, res.Name)
|
||||
assert.Equal(t, tc.rule.Schedule, res.Schedule)
|
||||
}
|
||||
defer repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewRule(t *testing.T) {
|
||||
svc, repo, _, _ := newService(t)
|
||||
|
||||
now := time.Now().Add(time.Hour)
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
id string
|
||||
res re.Rule
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "view rule successfully",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
res: re.Rule{
|
||||
Name: ruleName,
|
||||
ID: ruleID,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: re.Schedule{
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "view rule with failed repo",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("ViewRule", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
res, err := svc.ViewRule(context.Background(), tc.session, tc.id)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.res, res)
|
||||
}
|
||||
defer repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRule(t *testing.T) {
|
||||
svc, repo, _, _ := newService(t)
|
||||
|
||||
newName := namegen.Generate()
|
||||
now := time.Now().Add(time.Hour)
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
rule re.Rule
|
||||
res re.Rule
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "update rule successfully",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
rule: re.Rule{
|
||||
Name: newName,
|
||||
ID: ruleID,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: re.Schedule{
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
res: re.Rule{
|
||||
Name: newName,
|
||||
ID: ruleID,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: re.Schedule{
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
UpdatedAt: now,
|
||||
UpdatedBy: userID,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update rule with failed repo",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
rule: re.Rule{
|
||||
Name: ruleName,
|
||||
ID: ruleID,
|
||||
InputChannel: inputChannel,
|
||||
Schedule: re.Schedule{
|
||||
Recurring: re.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
},
|
||||
Status: re.EnabledStatus,
|
||||
CreatedBy: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("UpdateRule", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
res, err := svc.UpdateRule(context.Background(), tc.session, tc.rule)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.res, res)
|
||||
}
|
||||
defer repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListRules(t *testing.T) {
|
||||
svc, repo, _, _ := newService(t)
|
||||
numRules := 50
|
||||
now := time.Now().Add(time.Hour)
|
||||
var rules []re.Rule
|
||||
for i := 0; i < numRules; i++ {
|
||||
r := re.Rule{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Name: namegen.Generate(),
|
||||
DomainID: domainID,
|
||||
Status: re.EnabledStatus,
|
||||
CreatedAt: now,
|
||||
CreatedBy: userID,
|
||||
Schedule: re.Schedule{
|
||||
Recurring: re.Daily,
|
||||
Time: now.Add(1 * time.Hour),
|
||||
RecurringPeriod: 1,
|
||||
StartDateTime: now.Add(-1 * time.Hour),
|
||||
},
|
||||
}
|
||||
rules = append(rules, r)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
pageMeta re.PageMeta
|
||||
res re.Page
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list rules successfully",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
pageMeta: re.PageMeta{},
|
||||
res: re.Page{
|
||||
PageMeta: re.PageMeta{
|
||||
Total: uint64(numRules),
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
},
|
||||
Rules: rules[0:10],
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules successfully with limit",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
pageMeta: re.PageMeta{
|
||||
Limit: 100,
|
||||
},
|
||||
res: re.Page{
|
||||
PageMeta: re.PageMeta{
|
||||
Total: uint64(numRules),
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
Rules: rules[0:numRules],
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules successfully with offset",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
pageMeta: re.PageMeta{
|
||||
Offset: 20,
|
||||
Limit: 10,
|
||||
},
|
||||
res: re.Page{
|
||||
PageMeta: re.PageMeta{
|
||||
Total: uint64(numRules),
|
||||
Offset: 20,
|
||||
Limit: 10,
|
||||
},
|
||||
Rules: rules[20:30],
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list rules with failed repo",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
pageMeta: re.PageMeta{},
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(tc.res, tc.err)
|
||||
res, err := svc.ListRules(context.Background(), tc.session, tc.pageMeta)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.res, res)
|
||||
}
|
||||
defer repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveRule(t *testing.T) {
|
||||
svc, repo, _, _ := newService(t)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "remove rule successfully",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "remove rule with failed repo",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("RemoveRule", mock.Anything, mock.Anything).Return(tc.err)
|
||||
err := svc.RemoveRule(context.Background(), tc.session, tc.id)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
defer repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableRule(t *testing.T) {
|
||||
svc, repo, _, _ := newService(t)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
id string
|
||||
status re.Status
|
||||
res re.Rule
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "enable rule successfully",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
status: re.EnabledStatus,
|
||||
res: re.Rule{
|
||||
ID: ruleID,
|
||||
Name: ruleName,
|
||||
DomainID: domainID,
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
Schedule: schedule,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "enable rule with failed repo",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
status: re.EnabledStatus,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("UpdateRuleStatus", context.Background(), tc.id, tc.status).Return(tc.res, tc.err)
|
||||
res, err := svc.EnableRule(context.Background(), tc.session, tc.id)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.res, res)
|
||||
}
|
||||
defer repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisableRule(t *testing.T) {
|
||||
svc, repo, _, _ := newService(t)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
id string
|
||||
status re.Status
|
||||
res re.Rule
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "disable rule successfully",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
status: re.DisabledStatus,
|
||||
res: re.Rule{
|
||||
ID: ruleID,
|
||||
Name: ruleName,
|
||||
DomainID: domainID,
|
||||
InputChannel: inputChannel,
|
||||
Status: re.DisabledStatus,
|
||||
Schedule: schedule,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "disable rule with failed repo",
|
||||
session: authn.Session{
|
||||
UserID: userID,
|
||||
DomainID: domainID,
|
||||
},
|
||||
id: ruleID,
|
||||
status: re.DisabledStatus,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("UpdateRuleStatus", mock.Anything, tc.id, tc.status).Return(tc.res, tc.err)
|
||||
res, err := svc.DisableRule(context.Background(), tc.session, tc.id)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.Equal(t, tc.res, res)
|
||||
}
|
||||
defer repoCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeAsync(t *testing.T) {
|
||||
svc, repo, pubmocks, _ := newService(t)
|
||||
now := time.Now()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
message any
|
||||
pageMeta re.PageMeta
|
||||
page re.Page
|
||||
listErr error
|
||||
publishErr error
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
desc: "consume message with empty rules",
|
||||
message: &messaging.Message{
|
||||
Channel: inputChannel,
|
||||
Created: now.Unix(),
|
||||
},
|
||||
pageMeta: re.PageMeta{
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
},
|
||||
page: re.Page{
|
||||
Rules: []re.Rule{},
|
||||
},
|
||||
listErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "consume message with rules",
|
||||
message: &messaging.Message{
|
||||
Channel: inputChannel,
|
||||
Created: now.Unix(),
|
||||
},
|
||||
pageMeta: re.PageMeta{
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
},
|
||||
page: re.Page{
|
||||
Rules: []re.Rule{
|
||||
{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Name: namegen.Generate(),
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
Logic: re.Script{
|
||||
Type: re.ScriptType(0),
|
||||
},
|
||||
OutputChannel: "output.channel",
|
||||
Schedule: schedule,
|
||||
},
|
||||
},
|
||||
},
|
||||
listErr: nil,
|
||||
},
|
||||
{
|
||||
desc: "consume message with unsupported message type",
|
||||
message: "unsupported message type",
|
||||
pageMeta: re.PageMeta{
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
},
|
||||
page: re.Page{},
|
||||
},
|
||||
{
|
||||
desc: "consume json message",
|
||||
message: mgjson.Message{},
|
||||
pageMeta: re.PageMeta{
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
},
|
||||
page: re.Page{},
|
||||
listErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
|
||||
repoCall := repo.On("ListRules", mock.Anything, tc.pageMeta).Return(tc.page, tc.listErr).Run(func(args mock.Arguments) {
|
||||
if tc.listErr != nil {
|
||||
err = tc.listErr
|
||||
}
|
||||
})
|
||||
repoCall1 := pubmocks.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(tc.publishErr)
|
||||
|
||||
svc.ConsumeAsync(ctx, tc.message)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.listErr), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.listErr, err))
|
||||
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartScheduler(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Minute)
|
||||
svc, repo, ticker := newService(t)
|
||||
svc, repo, _, ticker := newService(t)
|
||||
|
||||
noRecurringPeriod := re.Rule{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Name: namegen.Generate(),
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
Schedule: re.Schedule{
|
||||
StartDateTime: time.Now().Add(-time.Hour),
|
||||
Recurring: re.None,
|
||||
RecurringPeriod: 0,
|
||||
Time: time.Now().Add(-time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
weeklyRule := re.Rule{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Name: namegen.Generate(),
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
Schedule: re.Schedule{
|
||||
StartDateTime: time.Now().Add(-time.Hour),
|
||||
Recurring: re.Weekly,
|
||||
RecurringPeriod: 1,
|
||||
Time: time.Now().Add(-time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
monthlyRule := re.Rule{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Name: namegen.Generate(),
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
Schedule: re.Schedule{
|
||||
StartDateTime: time.Now().Add(-time.Hour),
|
||||
Recurring: re.Monthly,
|
||||
RecurringPeriod: 1,
|
||||
Time: time.Now().Add(-time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
pastRule := re.Rule{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Name: namegen.Generate(),
|
||||
InputChannel: inputChannel,
|
||||
Status: re.EnabledStatus,
|
||||
Schedule: re.Schedule{
|
||||
StartDateTime: time.Now().Add(-time.Hour),
|
||||
Recurring: re.None,
|
||||
RecurringPeriod: 1,
|
||||
Time: time.Now().Add(-time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -142,6 +783,62 @@ func TestStartScheduler(t *testing.T) {
|
||||
return context.WithCancel(context.Background())
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "start scheduler successfully with no recurring period",
|
||||
err: context.Canceled,
|
||||
pageMeta: re.PageMeta{
|
||||
Status: re.EnabledStatus,
|
||||
ScheduledBefore: &now,
|
||||
},
|
||||
page: re.Page{
|
||||
Rules: []re.Rule{noRecurringPeriod},
|
||||
},
|
||||
setupCtx: func() (context.Context, context.CancelFunc) {
|
||||
return context.WithCancel(context.Background())
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "start scheduler successfully with weekly schedule",
|
||||
err: context.Canceled,
|
||||
pageMeta: re.PageMeta{
|
||||
Status: re.EnabledStatus,
|
||||
ScheduledBefore: &now,
|
||||
},
|
||||
page: re.Page{
|
||||
Rules: []re.Rule{weeklyRule},
|
||||
},
|
||||
setupCtx: func() (context.Context, context.CancelFunc) {
|
||||
return context.WithCancel(context.Background())
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "start scheduler successfully with monthly schedule",
|
||||
err: context.Canceled,
|
||||
pageMeta: re.PageMeta{
|
||||
Status: re.EnabledStatus,
|
||||
ScheduledBefore: &now,
|
||||
},
|
||||
page: re.Page{
|
||||
Rules: []re.Rule{monthlyRule},
|
||||
},
|
||||
setupCtx: func() (context.Context, context.CancelFunc) {
|
||||
return context.WithCancel(context.Background())
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "start scheduler successfully processes rules with past schedule",
|
||||
err: context.Canceled,
|
||||
pageMeta: re.PageMeta{
|
||||
Status: re.EnabledStatus,
|
||||
ScheduledBefore: &now,
|
||||
},
|
||||
page: re.Page{
|
||||
Rules: []re.Rule{pastRule},
|
||||
},
|
||||
setupCtx: func() (context.Context, context.CancelFunc) {
|
||||
return context.WithCancel(context.Background())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
@@ -161,20 +858,16 @@ func TestStartScheduler(t *testing.T) {
|
||||
switch tc.desc {
|
||||
case "start scheduler with canceled context":
|
||||
cancel()
|
||||
case "start scheduler successfully processes rules":
|
||||
tickChan <- time.Now()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
case "start scheduler with rule to be run in the future":
|
||||
tickChan <- time.Now()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
case "start scheduler with list error":
|
||||
tickChan <- time.Now()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if err := svc.Errors(); err != nil {
|
||||
cancel()
|
||||
}
|
||||
default:
|
||||
tickChan <- time.Now()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}
|
||||
|
||||
err := <-errc
|
||||
|
||||
Reference in New Issue
Block a user