From 257db27769254fdccf8ea59c5e371be8d00b113f Mon Sep 17 00:00:00 2001 From: Steve Munene Date: Mon, 10 Nov 2025 20:03:10 +0300 Subject: [PATCH] MG-132 - Improve RE tests (#346) * initial implementation Signed-off-by: nyagamunene * add coverage for api tests Signed-off-by: nyagamunene * add coverage for api tests Signed-off-by: nyagamunene * add tests for handler Signed-off-by: nyagamunene * add tests for start schedular Signed-off-by: nyagamunene * fix failing linter Signed-off-by: nyagamunene * fix failing linter Signed-off-by: nyagamunene * fix failing linter Signed-off-by: nyagamunene * fix race condition Signed-off-by: nyagamunene * address comments Signed-off-by: nyagamunene * fix addrule test Signed-off-by: nyagamunene * fix list rule method Signed-off-by: nyagamunene * use sorting for the slice Signed-off-by: nyagamunene * fetch supermq Signed-off-by: nyagamunene --------- Signed-off-by: nyagamunene --- docker/supermq-docker/Dockerfile | 2 +- docker/supermq-docker/Dockerfile.dev | 2 +- docker/supermq-docker/docker-compose.yaml | 54 +- re/api/endpoints.go | 2 +- re/api/endpoints_test.go | 163 ++++ re/api/requests.go | 2 +- re/handlers.go | 6 +- re/postgres/repository_test.go | 998 ++++++++++++++++++++++ re/postgres/setup_test.go | 93 ++ re/service_test.go | 902 ++++++++++++++++++- re/status_test.go | 205 +++++ 11 files changed, 2380 insertions(+), 49 deletions(-) create mode 100644 re/postgres/repository_test.go create mode 100644 re/postgres/setup_test.go create mode 100644 re/status_test.go diff --git a/docker/supermq-docker/Dockerfile b/docker/supermq-docker/Dockerfile index 175f3a29a..f0c2fd7ce 100644 --- a/docker/supermq-docker/Dockerfile +++ b/docker/supermq-docker/Dockerfile @@ -1,7 +1,7 @@ # Copyright (c) Abstract Machines # SPDX-License-Identifier: Apache-2.0 -FROM golang:1.25.3-alpine AS builder +FROM golang:1.25.3-alpine3.22 AS builder ARG SVC ARG GOARCH ARG GOARM diff --git a/docker/supermq-docker/Dockerfile.dev b/docker/supermq-docker/Dockerfile.dev index 7d55569c2..7016e28c2 100644 --- a/docker/supermq-docker/Dockerfile.dev +++ b/docker/supermq-docker/Dockerfile.dev @@ -4,5 +4,5 @@ FROM scratch ARG SVC COPY $SVC /exe -COPY --from=alpine:latest /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt +COPY --from=alpine:3.22 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt ENTRYPOINT ["/exe"] diff --git a/docker/supermq-docker/docker-compose.yaml b/docker/supermq-docker/docker-compose.yaml index 3a04f7c4f..0406b67ef 100644 --- a/docker/supermq-docker/docker-compose.yaml +++ b/docker/supermq-docker/docker-compose.yaml @@ -26,7 +26,7 @@ volumes: services: spicedb: - image: "authzed/spicedb:v1.37.0" + image: docker.io/authzed/spicedb:v1.37.0 container_name: supermq-spicedb command: "serve" restart: "always" @@ -44,7 +44,7 @@ services: - spicedb-migrate spicedb-migrate: - image: "authzed/spicedb:v1.37.0" + image: docker.io/authzed/spicedb:v1.37.0 container_name: supermq-spicedb-migrate command: "migrate head" restart: "on-failure" @@ -57,7 +57,7 @@ services: - spicedb-db spicedb-db: - image: "postgres:16.2-alpine" + image: docker.io/postgres:18.0-alpine3.22 container_name: supermq-spicedb-db networks: - supermq-base-net @@ -72,7 +72,7 @@ services: command: ["postgres", "-c", "track_commit_timestamp=on"] auth-db: - image: postgres:16.2-alpine + image: docker.io/postgres:18.0-alpine3.22 container_name: supermq-auth-db restart: on-failure ports: @@ -87,7 +87,7 @@ services: - supermq-auth-db-volume:/var/lib/postgresql/data auth-redis: - image: redis:7.2.4-alpine + image: docker.io/redis:8.2.2-alpine3.22 container_name: supermq-auth-redis restart: on-failure networks: @@ -96,7 +96,7 @@ services: - supermq-auth-redis-volume:/data auth: - image: supermq/auth:${SMQ_RELEASE_TAG} + image: docker.io/supermq/auth:${SMQ_RELEASE_TAG} container_name: supermq-auth depends_on: - auth-db @@ -189,7 +189,7 @@ services: create_host_path: true domains-db: - image: postgres:16.2-alpine + image: docker.io/postgres:18.0-alpine3.22 container_name: supermq-domains-db restart: on-failure ports: @@ -204,7 +204,7 @@ services: - supermq-domains-db-volume:/var/lib/postgresql/data domains-redis: - image: redis:7.2.4-alpine + image: docker.io/redis:8.2.2-alpine3.22 container_name: supermq-domains-redis restart: on-failure networks: @@ -213,7 +213,7 @@ services: - supermq-domains-redis-volume:/data domains: - image: supermq/domains:${SMQ_RELEASE_TAG} + image: docker.io/supermq/domains:${SMQ_RELEASE_TAG} container_name: supermq-domains depends_on: - domains-db @@ -380,7 +380,7 @@ services: create_host_path: true nginx: - image: nginx:1.25.4-alpine + image: docker.io/nginx:1.29.2-alpine3.22 container_name: supermq-nginx restart: on-failure volumes: @@ -423,7 +423,7 @@ services: hard: 65536 clients-db: - image: postgres:16.2-alpine + image: docker.io/postgres:18.0-alpine3.22 container_name: supermq-clients-db restart: on-failure command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" @@ -440,7 +440,7 @@ services: - supermq-clients-db-volume:/var/lib/postgresql/data clients-redis: - image: redis:7.2.4-alpine + image: docker.io/redis:8.2.2-alpine3.22 container_name: supermq-clients-redis restart: on-failure networks: @@ -449,7 +449,7 @@ services: - supermq-clients-redis-volume:/data clients: - image: supermq/clients:${SMQ_RELEASE_TAG} + image: docker.io/supermq/clients:${SMQ_RELEASE_TAG} container_name: supermq-clients depends_on: - clients-db @@ -616,7 +616,7 @@ services: create_host_path: true channels-db: - image: postgres:16.2-alpine + image: docker.io/postgres:18.0-alpine3.22 container_name: supermq-channels-db restart: on-failure command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" @@ -633,7 +633,7 @@ services: - supermq-channels-db-volume:/var/lib/postgresql/data channels-redis: - image: redis:7.2.4-alpine + image: docker.io/redis:8.2.2-alpine3.22 container_name: supermq-channels-redis restart: on-failure networks: @@ -642,7 +642,7 @@ services: - supermq-channels-redis-volume:/data channels: - image: supermq/channels:${SMQ_RELEASE_TAG} + image: docker.io/supermq/channels:${SMQ_RELEASE_TAG} container_name: supermq-channels depends_on: - channels-db @@ -807,7 +807,7 @@ services: create_host_path: true users-db: - image: postgres:16.2-alpine + image: docker.io/postgres:18.0-alpine3.22 container_name: supermq-users-db restart: on-failure command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" @@ -824,7 +824,7 @@ services: - supermq-users-db-volume:/var/lib/postgresql/data users: - image: supermq/users:${SMQ_RELEASE_TAG} + image: docker.io/supermq/users:${SMQ_RELEASE_TAG} container_name: supermq-users depends_on: - users-db @@ -933,7 +933,7 @@ services: create_host_path: true groups-db: - image: postgres:16.2-alpine + image: docker.io/postgres:18.0-alpine3.22 container_name: supermq-groups-db restart: on-failure command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" @@ -950,7 +950,7 @@ services: - supermq-groups-db-volume:/var/lib/postgresql/data groups: - image: supermq/groups:${SMQ_RELEASE_TAG} + image: docker.io/supermq/groups:${SMQ_RELEASE_TAG} container_name: supermq-groups depends_on: - groups-db @@ -1113,7 +1113,7 @@ services: create_host_path: true jaeger: - image: jaegertracing/all-in-one:1.66.0 + image: docker.io/jaegertracing/all-in-one:1.74.0 container_name: supermq-jaeger environment: COLLECTOR_OTLP_ENABLED: ${SMQ_JAEGER_COLLECTOR_OTLP_ENABLED} @@ -1125,7 +1125,7 @@ services: - supermq-base-net mqtt-adapter: - image: supermq/mqtt:${SMQ_RELEASE_TAG} + image: docker.io/supermq/mqtt:${SMQ_RELEASE_TAG} container_name: supermq-mqtt depends_on: - clients @@ -1224,7 +1224,7 @@ services: create_host_path: true http-adapter: - image: supermq/http:${SMQ_RELEASE_TAG} + image: docker.io/supermq/http:${SMQ_RELEASE_TAG} container_name: supermq-http depends_on: - clients @@ -1336,7 +1336,7 @@ services: create_host_path: true coap-adapter: - image: supermq/coap:${SMQ_RELEASE_TAG} + image: docker.io/supermq/coap:${SMQ_RELEASE_TAG} container_name: supermq-coap depends_on: - clients @@ -1454,7 +1454,7 @@ services: create_host_path: true ws-adapter: - image: supermq/ws:${SMQ_RELEASE_TAG} + image: docker.io/supermq/ws:${SMQ_RELEASE_TAG} container_name: supermq-ws depends_on: - clients @@ -1566,7 +1566,7 @@ services: create_host_path: true rabbitmq: - image: rabbitmq:4.0.5-management-alpine + image: docker.io/rabbitmq:4.1.4-management-alpine container_name: supermq-rabbitmq restart: on-failure environment: @@ -1587,7 +1587,7 @@ services: - supermq-base-net nats: - image: nats:2.10.25-alpine + image: docker.io/nats:2.12.0-alpine3.22 container_name: supermq-nats restart: on-failure command: "--config=/etc/nats/nats.conf" diff --git a/re/api/endpoints.go b/re/api/endpoints.go index fa8c3b356..af74680e4 100644 --- a/re/api/endpoints.go +++ b/re/api/endpoints.go @@ -134,7 +134,7 @@ func listRulesEndpoint(s re.Service) endpoint.Endpoint { } page, err := s.ListRules(ctx, session, req.PageMeta) if err != nil { - return rulesPageRes{}, nil + return rulesPageRes{}, err } ret := rulesPageRes{ Page: page, diff --git a/re/api/endpoints_test.go b/re/api/endpoints_test.go index 39f127f85..e6d9f0a35 100644 --- a/re/api/endpoints_test.go +++ b/re/api/endpoints_test.go @@ -443,6 +443,14 @@ func TestListRulesEndpoint(t *testing.T) { status: http.StatusBadRequest, err: apiutil.ErrInvalidDirection, }, + { + desc: "list rules with invalid order", + domainID: domainID, + token: validToken, + query: "order=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidOrder, + }, { desc: "list rule with limit that is too big", domainID: domainID, @@ -507,6 +515,14 @@ func TestListRulesEndpoint(t *testing.T) { status: http.StatusBadRequest, err: apiutil.ErrInvalidQueryParams, }, + { + desc: "list rules with service error", + domainID: domainID, + token: validToken, + listRulesResponse: re.Page{}, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, } for _, tc := range cases { @@ -827,6 +843,153 @@ func TestUpdateRuleTagsEndpoint(t *testing.T) { } } +func TestUpdateRuleScheduleEndpoint(t *testing.T) { + ts, svc, authn := newRuleEngineServer() + defer ts.Close() + + updateScheduleReq := pkgSch.Schedule{ + StartDateTime: future, + Time: future.Add(2 * time.Hour), + Recurring: pkgSch.Weekly, + RecurringPeriod: 2, + } + + ruleWithSchedule := rule + ruleWithSchedule.Schedule = updateScheduleReq + + cases := []struct { + desc string + token string + id string + domainID string + schedule pkgSch.Schedule + contentType string + session smqauthn.Session + svcResp re.Rule + svcErr error + status int + authnErr error + err error + }{ + { + desc: "update rule schedule successfully", + token: validToken, + domainID: domainID, + id: validID, + schedule: updateScheduleReq, + contentType: contentType, + svcResp: ruleWithSchedule, + status: http.StatusOK, + err: nil, + }, + { + desc: "update rule schedule with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + schedule: updateScheduleReq, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "update rule schedule with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + schedule: updateScheduleReq, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update rule schedule with empty domainID", + token: validToken, + id: validID, + schedule: updateScheduleReq, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "update rule schedule with invalid content type", + token: validToken, + id: validID, + domainID: domainID, + schedule: updateScheduleReq, + contentType: "application/xml", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update rule schedule with start_datetime in past", + token: validToken, + id: validID, + domainID: domainID, + schedule: pkgSch.Schedule{ + StartDateTime: past, + Time: future, + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + }, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "update rule schedule with service error", + token: validToken, + id: validID, + domainID: domainID, + schedule: updateScheduleReq, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "update rule schedule with empty id", + token: validToken, + id: "", + domainID: domainID, + schedule: updateScheduleReq, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(map[string]any{ + "schedule": tc.schedule, + }) + + req := testRequest{ + client: ts.Client(), + method: http.MethodPatch, + url: fmt.Sprintf("%s/%s/rules/%s/schedule", ts.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("UpdateRuleSchedule", mock.Anything, mock.Anything, mock.Anything).Return(tc.svcResp, tc.svcErr) + + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + + svcCall.Unset() + authCall.Unset() + }) + } +} + func TestEnableRuleEndpoint(t *testing.T) { ts, svc, authn := newRuleEngineServer() defer ts.Close() diff --git a/re/api/requests.go b/re/api/requests.go index cdd7a9c6a..cb3c9356f 100644 --- a/re/api/requests.go +++ b/re/api/requests.go @@ -56,7 +56,7 @@ func (req listRulesReq) validate() error { switch req.Order { case "", api.NameKey, api.CreatedAtOrder, api.UpdatedAtOrder: default: - return apiutil.ErrInvalidOrder + return errors.Wrap(apiutil.ErrInvalidOrder, apiutil.ErrValidation) } if req.Dir != api.AscDir && req.Dir != api.DescDir { diff --git a/re/handlers.go b/re/handlers.go index ae4274da6..851890feb 100644 --- a/re/handlers.go +++ b/re/handlers.go @@ -142,7 +142,7 @@ func (re *re) StartScheduler(ctx context.Context) error { } for _, r := range page.Rules { - go func(rule Rule) { + go func(rule Rule, dueTime time.Time) { if _, err := re.repo.UpdateRuleDue(ctx, rule.ID, rule.Schedule.NextDue()); err != nil { re.runInfo <- pkglog.RunInfo{Level: slog.LevelError, Message: fmt.Sprintf("failed to update rule: %s", err), Details: []slog.Attr{slog.Time("time", time.Now().UTC())}} return @@ -153,10 +153,10 @@ func (re *re) StartScheduler(ctx context.Context) error { Channel: rule.InputChannel, Subtopic: rule.InputTopic, Protocol: protocol, - Created: due.Unix(), + Created: dueTime.Unix(), } re.runInfo <- re.process(ctx, rule, msg) - }(r) + }(r, due) } // Reset due, it will reset in the page meta as well. due = time.Now().UTC() diff --git a/re/postgres/repository_test.go b/re/postgres/repository_test.go new file mode 100644 index 000000000..61c2cded4 --- /dev/null +++ b/re/postgres/repository_test.go @@ -0,0 +1,998 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "sort" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + "github.com/absmach/magistrala/pkg/schedule" + "github.com/absmach/magistrala/re" + "github.com/absmach/magistrala/re/outputs" + "github.com/absmach/magistrala/re/postgres" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" +) + +const ( + ascDir = "asc" + descDir = "desc" + nameOrder = "name" + createdAtOrder = "created_at" + updatedAtOrder = "updated_at" +) + +var ( + namegen = namegenerator.NewGenerator() + idProvider = uuid.New() +) + +func TestAddRule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + Tags: []string{"test", "rule"}, + InputChannel: generateUUID(t), + InputTopic: "temperature", + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Outputs: re.Outputs{ + &outputs.Alarm{}, + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + Metadata: map[string]any{ + "key": "value", + }, + } + + scheduleName := namegen.Generate() + scheduleDomain := generateUUID(t) + scheduleChannel := generateUUID(t) + scheduleCreatedBy := generateUUID(t) + scheduleCreatedAt := time.Now().UTC().Truncate(time.Microsecond) + scheduleUpdatedBy := generateUUID(t) + scheduleUpdatedAt := time.Now().UTC().Truncate(time.Microsecond) + scheduleStartTime := time.Now().UTC().Add(time.Hour).Truncate(time.Microsecond) + scheduleTime := time.Now().UTC().Add(2 * time.Hour).Truncate(time.Microsecond) + + scheduleRule := re.Rule{ + ID: generateUUID(t), + Name: scheduleName, + DomainID: scheduleDomain, + InputChannel: scheduleChannel, + InputTopic: "humidity", + Logic: re.Script{ + Type: re.LuaType, + Value: "return value > 50", + }, + Schedule: schedule.Schedule{ + StartDateTime: scheduleStartTime, + Time: scheduleTime, + Recurring: schedule.Daily, + RecurringPeriod: 1, + }, + Status: re.EnabledStatus, + CreatedAt: scheduleCreatedAt, + CreatedBy: scheduleCreatedBy, + UpdatedAt: scheduleUpdatedAt, + UpdatedBy: scheduleUpdatedBy, + Metadata: re.Metadata{}, + } + + outputsName := namegen.Generate() + outputsDomain := generateUUID(t) + outputsChannel := generateUUID(t) + outputsCreatedBy := generateUUID(t) + outputsCreatedAt := time.Now().UTC().Truncate(time.Microsecond) + outputsUpdatedBy := generateUUID(t) + outputsUpdatedAt := time.Now().UTC().Truncate(time.Microsecond) + outputsRuleID := generateUUID(t) + + outputsRule := re.Rule{ + ID: outputsRuleID, + Name: outputsName, + DomainID: outputsDomain, + InputChannel: outputsChannel, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: generateUUID(t), + Topic: "alerts", + }, + &outputs.SenML{}, + }, + Status: re.EnabledStatus, + CreatedAt: outputsCreatedAt, + CreatedBy: outputsCreatedBy, + UpdatedAt: outputsUpdatedAt, + UpdatedBy: outputsUpdatedBy, + Metadata: re.Metadata{}, + } + + cases := []struct { + desc string + rule re.Rule + resp re.Rule + err error + }{ + { + desc: "valid rule", + rule: rule, + resp: rule, + err: nil, + }, + { + desc: "duplicate rule", + rule: rule, + resp: re.Rule{}, + err: repoerr.ErrConflict, + }, + + { + desc: "rule with schedule", + rule: scheduleRule, + resp: scheduleRule, + err: nil, + }, + { + desc: "rule with outputs", + rule: outputsRule, + resp: outputsRule, + err: nil, + }, + { + desc: "invalid metadata", + rule: re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Metadata: map[string]any{ + "key": make(chan int), + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + resp: re.Rule{}, + err: repoerr.ErrMalformedEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + addedRule, err := repo.AddRule(context.Background(), tc.rule) + if err == nil { + tc.resp.ID = addedRule.ID + assert.Equal(t, tc.resp, addedRule, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, addedRule)) + } + }) + } +} + +func TestViewRule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + InputTopic: "temperature", + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + Metadata: map[string]any{ + "key": "value", + }, + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + resp re.Rule + err error + }{ + { + desc: "valid rule", + id: rule.ID, + resp: rule, + err: nil, + }, + { + desc: "non existing rule", + id: generateUUID(t), + resp: re.Rule{}, + err: repoerr.ErrViewEntity, + }, + { + desc: "empty id", + id: "", + resp: re.Rule{}, + err: repoerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + retrievedRule, err := repo.ViewRule(context.Background(), tc.id) + assert.Equal(t, tc.resp, retrievedRule, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, retrievedRule)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func TestUpdateRule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + InputTopic: "temperature", + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + Metadata: map[string]any{ + "key": "value", + }, + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + newInputChannel := generateUUID(t) + newUpdatedBy := generateUUID(t) + + cases := []struct { + desc string + rule re.Rule + resp re.Rule + err error + }{ + { + desc: "valid rule update", + rule: re.Rule{ + ID: rule.ID, + Name: "updated-name", + InputChannel: newInputChannel, + InputTopic: "humidity", + Logic: re.Script{ + Type: re.LuaType, + Value: "return value > 30", + }, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: newUpdatedBy, + Metadata: map[string]any{ + "updated": "metadata", + }, + }, + resp: re.Rule{ + ID: rule.ID, + Name: "updated-name", + DomainID: rule.DomainID, + InputChannel: newInputChannel, + InputTopic: "humidity", + Logic: re.Script{ + Type: re.LuaType, + Value: "return value > 30", + }, + Status: rule.Status, + CreatedAt: rule.CreatedAt, + CreatedBy: rule.CreatedBy, + UpdatedAt: time.Time{}, + UpdatedBy: newUpdatedBy, + Metadata: map[string]any{ + "updated": "metadata", + }, + }, + err: nil, + }, + { + desc: "update non-existing rule", + rule: re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + InputChannel: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + resp: re.Rule{}, + err: repoerr.ErrNotFound, + }, + { + desc: "update with invalid metadata", + rule: re.Rule{ + ID: rule.ID, + InputChannel: generateUUID(t), + Metadata: map[string]any{ + "key": make(chan int), + }, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + resp: re.Rule{}, + err: repoerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRule(context.Background(), tc.rule) + if tc.err == nil { + tc.resp.UpdatedAt = updatedRule.UpdatedAt + } + assert.Equal(t, tc.resp, updatedRule, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, updatedRule)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func TestUpdateRuleStatus(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + rule re.Rule + status re.Status + err error + }{ + { + desc: "disable rule", + rule: re.Rule{ + ID: rule.ID, + Status: re.DisabledStatus, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + status: re.DisabledStatus, + err: nil, + }, + { + desc: "enable rule", + rule: re.Rule{ + ID: rule.ID, + Status: re.EnabledStatus, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + status: re.EnabledStatus, + err: nil, + }, + { + desc: "update non-existing rule status", + rule: re.Rule{ + ID: generateUUID(t), + Status: re.DisabledStatus, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRuleStatus(context.Background(), tc.rule) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.rule.ID, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.ID, updatedRule.ID)) + assert.Equal(t, tc.status, updatedRule.Status, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.status, updatedRule.Status)) + assert.Equal(t, tc.rule.UpdatedBy, updatedRule.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.UpdatedBy, updatedRule.UpdatedBy)) + } + }) + } +} + +func TestUpdateRuleTags(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Tags: []string{"tag1", "tag2"}, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + rule re.Rule + tags []string + err error + }{ + { + desc: "update tags", + rule: re.Rule{ + ID: rule.ID, + Tags: []string{"newtag1", "newtag2", "newtag3"}, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + tags: []string{"newtag1", "newtag2", "newtag3"}, + err: nil, + }, + { + desc: "update non-existing rule tags", + rule: re.Rule{ + ID: generateUUID(t), + Tags: []string{"tag"}, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRuleTags(context.Background(), tc.rule) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.rule.ID, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.ID, updatedRule.ID)) + assert.Equal(t, tc.tags, updatedRule.Tags, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.tags, updatedRule.Tags)) + assert.Equal(t, tc.rule.UpdatedBy, updatedRule.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.UpdatedBy, updatedRule.UpdatedBy)) + } + }) + } +} + +func TestUpdateRuleSchedule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + newSchedule := schedule.Schedule{ + StartDateTime: time.Now().UTC().Add(time.Hour).Truncate(time.Microsecond), + Time: time.Now().UTC().Add(2 * time.Hour).Truncate(time.Microsecond), + Recurring: schedule.Weekly, + RecurringPeriod: 2, + } + + cases := []struct { + desc string + rule re.Rule + schedule schedule.Schedule + err error + }{ + { + desc: "update schedule", + rule: re.Rule{ + ID: rule.ID, + Schedule: newSchedule, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + schedule: newSchedule, + err: nil, + }, + { + desc: "update non-existing rule schedule", + rule: re.Rule{ + ID: generateUUID(t), + Schedule: newSchedule, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRuleSchedule(context.Background(), tc.rule) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.rule.ID, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.ID, updatedRule.ID)) + assert.Equal(t, tc.schedule.Recurring, updatedRule.Schedule.Recurring, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.schedule.Recurring, updatedRule.Schedule.Recurring)) + assert.Equal(t, tc.schedule.RecurringPeriod, updatedRule.Schedule.RecurringPeriod, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.schedule.RecurringPeriod, updatedRule.Schedule.RecurringPeriod)) + assert.Equal(t, tc.rule.UpdatedBy, updatedRule.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.UpdatedBy, updatedRule.UpdatedBy)) + } + }) + } +} + +func TestUpdateRuleDue(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Schedule: schedule.Schedule{ + Time: time.Now().UTC().Add(time.Hour).Truncate(time.Microsecond), + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + newDue := time.Now().UTC().Add(3 * time.Hour).Truncate(time.Microsecond) + + cases := []struct { + desc string + id string + due time.Time + err error + }{ + { + desc: "update due time", + id: rule.ID, + due: newDue, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRuleDue(context.Background(), tc.id, tc.due) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.id, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.id, updatedRule.ID)) + assert.True(t, updatedRule.Schedule.Time.Sub(tc.due) < time.Second, fmt.Sprintf("%s: expected due time close to %v got %v\n", tc.desc, tc.due, updatedRule.Schedule.Time)) + } + }) + } +} + +func TestListRules(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + channelID := generateUUID(t) + items := make([]re.Rule, 100) + + for i := range 100 { + items[i] = re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: channelID, + Tags: []string{fmt.Sprintf("tag%d", i%10)}, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + if i%2 == 0 { + items[i].Status = re.DisabledStatus + } + if i%3 == 0 { + items[i].Schedule = schedule.Schedule{ + Time: time.Now().UTC().Add(time.Duration(i) * time.Hour), + Recurring: schedule.Daily, + } + } + rule, err := repo.AddRule(context.Background(), items[i]) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + items[i].ID = rule.ID + } + + cases := []struct { + desc string + pm re.PageMeta + count int + err error + }{ + { + desc: "list first page", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + }, + count: 10, + err: nil, + }, + { + desc: "list with offset", + pm: re.PageMeta{ + Offset: 10, + Limit: 20, + Status: re.AllStatus, + }, + count: 20, + err: nil, + }, + { + desc: "list by domain", + pm: re.PageMeta{ + Domain: domainID, + Offset: 0, + Limit: 200, + Status: re.AllStatus, + }, + count: 100, + err: nil, + }, + { + desc: "list by channel", + pm: re.PageMeta{ + InputChannel: channelID, + Offset: 0, + Limit: 200, + Status: re.AllStatus, + }, + count: 100, + err: nil, + }, + { + desc: "list enabled rules", + pm: re.PageMeta{ + Status: re.EnabledStatus, + Offset: 0, + Limit: 200, + }, + count: 50, + err: nil, + }, + { + desc: "list disabled rules", + pm: re.PageMeta{ + Status: re.DisabledStatus, + Offset: 0, + Limit: 200, + }, + count: 50, + err: nil, + }, + { + desc: "list by tag", + pm: re.PageMeta{ + Tag: "tag1", + Offset: 0, + Limit: 200, + Status: re.AllStatus, + }, + count: 10, + err: nil, + }, + { + desc: "list with zero limit returns all", + pm: re.PageMeta{ + Status: re.AllStatus, + }, + count: 100, + err: nil, + }, + { + desc: "list non-existing domain", + pm: re.PageMeta{ + Domain: generateUUID(t), + Offset: 0, + Limit: 10, + Status: re.AllStatus, + }, + count: 0, + err: nil, + }, + { + desc: "list ordered by name ascending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: nameOrder, + Dir: ascDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by name descending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: nameOrder, + Dir: descDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by created_at ascending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: createdAtOrder, + Dir: ascDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by created_at descending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: createdAtOrder, + Dir: descDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by updated_at ascending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: updatedAtOrder, + Dir: ascDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by updated_at descending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: updatedAtOrder, + Dir: descDir, + }, + count: 10, + err: nil, + }, + { + desc: "list with default order (updated_at desc)", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + }, + count: 10, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + page, err := repo.ListRules(context.Background(), tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + assert.Equal(t, tc.count, len(page.Rules), fmt.Sprintf("%s: expected %d rules, got %d", tc.desc, tc.count, len(page.Rules))) + if len(page.Rules) > 1 { + switch tc.pm.Order { + case nameOrder: + if tc.pm.Dir == ascDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].Name <= page.Rules[j].Name + }), "Expected names to be sorted ascending") + } else { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].Name >= page.Rules[j].Name + }), "Expected names to be sorted descending") + } + case createdAtOrder: + if tc.pm.Dir == ascDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].CreatedAt.Before(page.Rules[j].CreatedAt) + }), "Expected created_at to be sorted ascending") + } else { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].CreatedAt.After(page.Rules[j].CreatedAt) + }), "Expected created_at to be sorted descending") + } + case updatedAtOrder: + if tc.pm.Dir == ascDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].UpdatedAt.Before(page.Rules[j].UpdatedAt) + }), "Expected updated_at to be sorted ascending") + } else { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].UpdatedAt.After(page.Rules[j].UpdatedAt) + }), "Expected updated_at to be sorted descending") + } + } + } + }) + } +} + +func TestRemoveRule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "remove existing rule", + id: rule.ID, + err: nil, + }, + { + desc: "remove non-existing rule", + id: generateUUID(t), + err: repoerr.ErrNotFound, + }, + { + desc: "remove already removed rule", + id: rule.ID, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.RemoveRule(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func generateUUID(t *testing.T) string { + ulid, err := idProvider.ID() + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + return ulid +} diff --git a/re/postgres/setup_test.go b/re/postgres/setup_test.go new file mode 100644 index 000000000..0ee52b2a1 --- /dev/null +++ b/re/postgres/setup_test.go @@ -0,0 +1,93 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "database/sql" + "fmt" + "log" + "os" + "testing" + "time" + + repostgres "github.com/absmach/magistrala/re/postgres" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" + dockertest "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "go.opentelemetry.io/otel" +) + +var ( + db *sqlx.DB + database postgres.Database + tracer = otel.Tracer("repo_tests") +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + // exponential backoff-retry, because the application in the container might not be ready to accept connections yet + pool.MaxWait = 120 * time.Second + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err := sql.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := postgres.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + if db, err = postgres.Setup(dbConfig, *repostgres.Migration()); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + database = postgres.NewDatabase(db, dbConfig, tracer) + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/re/service_test.go b/re/service_test.go index b5fb8c3de..f5d949368 100644 --- a/re/service_test.go +++ b/re/service_test.go @@ -6,6 +6,7 @@ package re_test import ( "context" "fmt" + "log/slog" "testing" "time" @@ -30,6 +31,17 @@ import ( "github.com/stretchr/testify/mock" ) +// unknownOutput is a mock output type that doesn't match any known output type. +type unknownOutput struct{} + +func (u *unknownOutput) Run(ctx context.Context, msg *messaging.Message, val any) error { + return nil +} + +func (u *unknownOutput) MarshalJSON() ([]byte, error) { + return []byte(`{"type": "unknown"}`), nil +} + var ( namegen = namegenerator.NewGenerator() userID = testsutil.GenerateUUID(&testing.T{}) @@ -47,18 +59,19 @@ var ( } ) -func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *tmocks.Ticker) { +func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *tmocks.Ticker, *emocks.Emailer) { repo := new(mocks.Repository) mockTicker := new(tmocks.Ticker) idProvider := uuid.NewMock() pubsub := pubsubmocks.NewPubSub(t) readersSvc := new(readmocks.ReadersServiceClient) e := new(emocks.Emailer) - return re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc), repo, pubsub, mockTicker + return re.NewService(repo, runInfo, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc), repo, pubsub, mockTicker, e } func TestAddRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + // nolint:dogsled + svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo)) ruleName := namegen.Generate() now := time.Now().Add(time.Hour) cases := []struct { @@ -115,6 +128,38 @@ func TestAddRule(t *testing.T) { }, err: repoerr.ErrCreateEntity, }, + { + desc: "Add rule with non-zero StartDateTime", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + StartDateTime: now, + Recurring: pkgSch.Weekly, + RecurringPeriod: 2, + Time: now.Add(2 * time.Hour), + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + StartDateTime: now, + Recurring: pkgSch.Weekly, + RecurringPeriod: 2, + Time: now.Add(2 * time.Hour), + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + err: nil, + }, } for _, tc := range cases { @@ -133,7 +178,8 @@ func TestAddRule(t *testing.T) { } func TestViewRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + // nolint:dogsled + svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo)) now := time.Now().Add(time.Hour) cases := []struct { @@ -191,7 +237,8 @@ func TestViewRule(t *testing.T) { } func TestUpdateRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + // nolint:dogsled + svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo)) newName := namegen.Generate() now := time.Now().Add(time.Hour) @@ -276,7 +323,8 @@ func TestUpdateRule(t *testing.T) { } func TestUpdateRuleTags(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + // nolint:dogsled + svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo)) cases := []struct { desc string @@ -331,8 +379,77 @@ func TestUpdateRuleTags(t *testing.T) { } } +func TestUpdateRuleSchedule(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo)) + + now := time.Now().UTC() + future := now.Add(2 * time.Hour) + newSchedule := pkgSch.Schedule{ + StartDateTime: future, + Time: future.Add(time.Hour), + Recurring: pkgSch.Weekly, + RecurringPeriod: 2, + } + + cases := []struct { + desc string + session authn.Session + updateReq re.Rule + repoResp re.Rule + repoErr error + err error + }{ + { + desc: "update rule schedule successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + updateReq: re.Rule{ + ID: testsutil.GenerateUUID(t), + Schedule: newSchedule, + }, + repoResp: re.Rule{ + ID: testsutil.GenerateUUID(t), + Schedule: newSchedule, + UpdatedAt: now, + UpdatedBy: userID, + }, + }, + { + desc: "update rule schedule with repo error", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + updateReq: re.Rule{ + ID: testsutil.GenerateUUID(t), + Schedule: newSchedule, + }, + repoErr: repoerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("UpdateRuleSchedule", context.Background(), mock.Anything).Return(tc.repoResp, tc.repoErr) + got, err := svc.UpdateRuleSchedule(context.Background(), tc.session, tc.updateReq) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + if err == nil { + assert.Equal(t, tc.repoResp, got) + ok := repo.AssertCalled(t, "UpdateRuleSchedule", context.Background(), mock.Anything) + assert.True(t, ok, fmt.Sprintf("UpdateRuleSchedule was not called on %s", tc.desc)) + } + repoCall.Unset() + }) + } +} + func TestListRules(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + // nolint:dogsled + svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo)) numRules := 50 now := time.Now().Add(time.Hour) var rules []re.Rule @@ -354,6 +471,19 @@ func TestListRules(t *testing.T) { rules = append(rules, r) } + goRule := re.Rule{ + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + Status: re.EnabledStatus, + CreatedAt: now, + CreatedBy: userID, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + } + cases := []struct { desc string session authn.Session @@ -376,6 +506,21 @@ func TestListRules(t *testing.T) { }, err: nil, }, + { + desc: "list rules with go type", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: re.PageMeta{}, + res: re.Page{ + Total: 1, + Offset: 0, + Limit: 10, + Rules: []re.Rule{goRule}, + }, + err: nil, + }, { desc: "list rules successfully with limit", session: authn.Session{ @@ -437,7 +582,8 @@ func TestListRules(t *testing.T) { } func TestRemoveRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + // nolint:dogsled + svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo)) cases := []struct { desc string @@ -477,7 +623,8 @@ func TestRemoveRule(t *testing.T) { } func TestEnableRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + // nolint:dogsled + svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo)) now := time.Now() @@ -536,7 +683,8 @@ func TestEnableRule(t *testing.T) { } func TestDisableRule(t *testing.T) { - svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + // nolint:dogsled + svc, repo, _, _, _ := newService(t, make(chan pkglog.RunInfo)) now := time.Now() @@ -595,7 +743,7 @@ func TestDisableRule(t *testing.T) { } func TestHandle(t *testing.T) { - svc, repo, pubmocks, _ := newService(t, make(chan pkglog.RunInfo)) + svc, repo, pubmocks, _, emailer := newService(t, make(chan pkglog.RunInfo)) now := time.Now() scheduled := false @@ -619,7 +767,239 @@ func TestHandle(t *testing.T) { listErr: nil, }, { - desc: "consume message with rules", + desc: "consume message with Lua script returning true", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "return message.payload", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script returning false", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "return false", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script with no outputs", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "return message.payload", + }, + Outputs: re.Outputs{}, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script returning nil", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "return nil", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script with invalid syntax", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "invalid lua syntax {{{", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and Alarm output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 30.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return {severity = 2, description = "High temperature"}`, + }, + Outputs: re.Outputs{ + &outputs.Alarm{ + RuleID: testsutil.GenerateUUID(t), + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and SenML output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return {bn = "sensor1", n = "temperature", v = 25.5}`, + }, + Outputs: re.Outputs{ + &outputs.SenML{}, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and Email output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return message.payload`, + }, + Outputs: re.Outputs{ + &outputs.Email{ + To: []string{"test@example.com"}, + Subject: "Temperature Alert", + Content: "Temperature: {{.Result}}", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with rules using GoType", message: &messaging.Message{ Channel: inputChannel, Created: now.Unix(), @@ -632,7 +1012,8 @@ func TestHandle(t *testing.T) { InputChannel: inputChannel, Status: re.EnabledStatus, Logic: re.Script{ - Type: re.ScriptType(0), + Type: re.GoType, + Value: "func() bool { return true }", }, Outputs: re.Outputs{ &outputs.ChannelPublisher{ @@ -646,6 +1027,333 @@ func TestHandle(t *testing.T) { }, listErr: nil, }, + { + desc: "consume message with GoType logic returning false", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return false }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType invalid logic value", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "invalid go code {{{", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType missing logicFunction", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func someOtherFunc() bool { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType invalid function signature", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "var logicFunction = 42", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType function logicFunction properly named", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func logicFunction() any { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType returning non-bool", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func() any { return \"not a bool\" }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType and JSON payload", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25, "humidity": 60}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType and invalid JSON payload", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`invalid json {{{`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and Postgres output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5, "humidity": 60}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return message.payload`, + }, + Outputs: re.Outputs{ + &outputs.Postgres{ + Host: "localhost", + Port: 5432, + User: "test", + Password: "test", + Database: "testdb", + Table: "sensor_data", + Mapping: `{"temperature": {{.Result.temperature}}, "humidity": {{.Result.humidity}}}`, + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and Slack output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return message.payload`, + }, + Outputs: re.Outputs{ + &outputs.Slack{ + Token: "xoxb-test-token", + ChannelID: "C12345678", + Message: `{"text": "Temperature: {{.Result.temperature}}"}`, + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and unknown output type", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return message.payload`, + }, + Outputs: re.Outputs{ + &unknownOutput{}, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, } for _, tc := range cases { @@ -657,15 +1365,19 @@ func TestHandle(t *testing.T) { err = tc.listErr } }) - repoCall1 := pubmocks.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(tc.publishErr) + repoCall1 := pubmocks.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(tc.publishErr).Maybe() + repoCall2 := emailer.On("SendEmailNotification", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() err = svc.Handle(tc.message) assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) + assert.True(t, errors.Contains(err, tc.listErr), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.listErr, err)) repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() }) } } @@ -673,7 +1385,7 @@ func TestHandle(t *testing.T) { func TestStartScheduler(t *testing.T) { now := time.Now().Truncate(time.Minute) ri := make(chan pkglog.RunInfo) - svc, repo, _, ticker := newService(t, ri) + svc, repo, _, ticker, _ := newService(t, ri) ctxCases := []struct { desc string @@ -742,4 +1454,164 @@ func TestStartScheduler(t *testing.T) { tickCall1.Unset() }) } + + schedulerCases := []struct { + desc string + rules []re.Rule + listErr error + updateDueErr error + expectedRunInfo int + }{ + { + desc: "start scheduler with successful rule processing", + rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: inputChannel, + Status: re.EnabledStatus, + Schedule: pkgSch.Schedule{ + StartDateTime: now.Add(-time.Hour), + Time: now.Add(time.Hour), + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + }, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + }, + }, + listErr: nil, + updateDueErr: nil, + expectedRunInfo: 1, + }, + { + desc: "start scheduler with multiple rules", + rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: inputChannel, + Status: re.EnabledStatus, + Schedule: pkgSch.Schedule{ + StartDateTime: now.Add(-time.Hour), + Time: now.Add(time.Hour), + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + }, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + }, + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: inputChannel, + Status: re.EnabledStatus, + Schedule: pkgSch.Schedule{ + StartDateTime: now.Add(-time.Hour), + Time: now.Add(time.Hour), + Recurring: pkgSch.Weekly, + RecurringPeriod: 1, + }, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + }, + }, + listErr: nil, + updateDueErr: nil, + expectedRunInfo: 2, + }, + { + desc: "start scheduler with list rules error", + rules: []re.Rule{}, + listErr: repoerr.ErrViewEntity, + updateDueErr: nil, + expectedRunInfo: 1, + }, + { + desc: "start scheduler with update due error", + rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: inputChannel, + Status: re.EnabledStatus, + Schedule: pkgSch.Schedule{ + StartDateTime: now.Add(-time.Hour), + Time: now.Add(time.Hour), + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + }, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + }, + }, + listErr: nil, + updateDueErr: repoerr.ErrUpdateEntity, + expectedRunInfo: 1, + }, + } + + for _, tc := range schedulerCases { + t.Run(tc.desc, func(t *testing.T) { + page := re.Page{ + Rules: tc.rules, + Total: uint64(len(tc.rules)), + } + + repoCall := repo.On("ListRules", mock.Anything, mock.Anything).Return(page, tc.listErr) + repoCall2 := repo.On("UpdateRuleDue", mock.Anything, mock.Anything, mock.Anything).Return(re.Rule{}, tc.updateDueErr) + tickChan := make(chan time.Time, 1) + tickCall := ticker.On("Tick").Return((<-chan time.Time)(tickChan)) + tickCall1 := ticker.On("Stop").Return() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = svc.StartScheduler(ctx) + }() + + tickChan <- now + + collected := 0 + timeout := time.After(500 * time.Millisecond) + for collected < tc.expectedRunInfo { + select { + case info := <-ri: + collected++ + if tc.listErr != nil { + assert.Equal(t, slog.LevelError, info.Level) + assert.Contains(t, info.Message, "failed to list rules") + } else if tc.updateDueErr != nil { + assert.Equal(t, slog.LevelError, info.Level) + assert.Contains(t, info.Message, "failed to update rule") + } else { + assert.True(t, info.Level == slog.LevelInfo || info.Level == slog.LevelWarn || info.Level == slog.LevelError) + } + case <-timeout: + t.Fatalf("timeout waiting for runInfo messages, expected %d got %d", tc.expectedRunInfo, collected) + } + } + + cancel() + time.Sleep(50 * time.Millisecond) + + repoCall.Unset() + repoCall2.Unset() + tickCall.Unset() + tickCall1.Unset() + }) + } } diff --git a/re/status_test.go b/re/status_test.go new file mode 100644 index 000000000..e534e920a --- /dev/null +++ b/re/status_test.go @@ -0,0 +1,205 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re_test + +import ( + "encoding/json" + "testing" + + "github.com/absmach/magistrala/re" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToStatus(t *testing.T) { + cases := []struct { + desc string + status string + res re.Status + err error + }{ + { + desc: "convert enabled status", + status: re.Enabled, + res: re.EnabledStatus, + err: nil, + }, + { + desc: "convert empty string to enabled status", + status: "", + res: re.EnabledStatus, + err: nil, + }, + { + desc: "convert disabled status", + status: re.Disabled, + res: re.DisabledStatus, + err: nil, + }, + { + desc: "convert deleted status", + status: re.Deleted, + res: re.DeletedStatus, + err: nil, + }, + { + desc: "convert all status", + status: re.All, + res: re.AllStatus, + err: nil, + }, + { + desc: "convert invalid status", + status: "invalid", + res: re.Status(0), + err: svcerr.ErrInvalidStatus, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + status, err := re.ToStatus(tc.status) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.res, status) + }) + } +} + +func TestStatusString(t *testing.T) { + cases := []struct { + desc string + status re.Status + res string + }{ + { + desc: "enabled status to string", + status: re.EnabledStatus, + res: re.Enabled, + }, + { + desc: "disabled status to string", + status: re.DisabledStatus, + res: re.Disabled, + }, + { + desc: "deleted status to string", + status: re.DeletedStatus, + res: re.Deleted, + }, + { + desc: "all status to string", + status: re.AllStatus, + res: re.All, + }, + { + desc: "unknown status to string", + status: re.Status(99), + res: re.Unknown, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + assert.Equal(t, tc.res, tc.status.String()) + }) + } +} + +func TestStatusMarshalJSON(t *testing.T) { + cases := []struct { + desc string + status re.Status + res string + }{ + { + desc: "marshal enabled status", + status: re.EnabledStatus, + res: `"enabled"`, + }, + { + desc: "marshal disabled status", + status: re.DisabledStatus, + res: `"disabled"`, + }, + { + desc: "marshal deleted status", + status: re.DeletedStatus, + res: `"deleted"`, + }, + { + desc: "marshal all status", + status: re.AllStatus, + res: `"all"`, + }, + { + desc: "marshal unknown status", + status: re.Status(99), + res: `"unknown"`, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data, err := json.Marshal(tc.status) + require.NoError(t, err) + assert.Equal(t, tc.res, string(data)) + }) + } +} + +func TestStatusUnmarshalJSON(t *testing.T) { + cases := []struct { + desc string + data string + res re.Status + err error + }{ + { + desc: "unmarshal enabled status", + data: `"enabled"`, + res: re.EnabledStatus, + err: nil, + }, + { + desc: "unmarshal disabled status", + data: `"disabled"`, + res: re.DisabledStatus, + err: nil, + }, + { + desc: "unmarshal deleted status", + data: `"deleted"`, + res: re.DeletedStatus, + err: nil, + }, + { + desc: "unmarshal all status", + data: `"all"`, + res: re.AllStatus, + err: nil, + }, + { + desc: "unmarshal empty string to enabled status", + data: `""`, + res: re.EnabledStatus, + err: nil, + }, + { + desc: "unmarshal invalid status", + data: `"invalid"`, + res: re.Status(0), + err: svcerr.ErrInvalidStatus, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var status re.Status + err := json.Unmarshal([]byte(tc.data), &status) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.res, status) + }) + } +}