MG-132 - Improve RE tests (#346)

* initial implementation

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* add coverage for api tests

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* add coverage for api tests

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* add tests for handler

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* add tests for start schedular

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix failing linter

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix failing linter

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix failing linter

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix race condition

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* address comments

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix addrule test

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix list rule method

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* use sorting for the slice

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fetch supermq

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

---------

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Steve Munene
2025-11-10 20:03:10 +03:00
committed by GitHub
parent a26d84b12d
commit 257db27769
11 changed files with 2380 additions and 49 deletions
+1 -1
View File
@@ -1,7 +1,7 @@
# Copyright (c) Abstract Machines
# SPDX-License-Identifier: Apache-2.0
FROM golang:1.25.3-alpine AS builder
FROM golang:1.25.3-alpine3.22 AS builder
ARG SVC
ARG GOARCH
ARG GOARM
+1 -1
View File
@@ -4,5 +4,5 @@
FROM scratch
ARG SVC
COPY $SVC /exe
COPY --from=alpine:latest /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
COPY --from=alpine:3.22 /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt
ENTRYPOINT ["/exe"]
+27 -27
View File
@@ -26,7 +26,7 @@ volumes:
services:
spicedb:
image: "authzed/spicedb:v1.37.0"
image: docker.io/authzed/spicedb:v1.37.0
container_name: supermq-spicedb
command: "serve"
restart: "always"
@@ -44,7 +44,7 @@ services:
- spicedb-migrate
spicedb-migrate:
image: "authzed/spicedb:v1.37.0"
image: docker.io/authzed/spicedb:v1.37.0
container_name: supermq-spicedb-migrate
command: "migrate head"
restart: "on-failure"
@@ -57,7 +57,7 @@ services:
- spicedb-db
spicedb-db:
image: "postgres:16.2-alpine"
image: docker.io/postgres:18.0-alpine3.22
container_name: supermq-spicedb-db
networks:
- supermq-base-net
@@ -72,7 +72,7 @@ services:
command: ["postgres", "-c", "track_commit_timestamp=on"]
auth-db:
image: postgres:16.2-alpine
image: docker.io/postgres:18.0-alpine3.22
container_name: supermq-auth-db
restart: on-failure
ports:
@@ -87,7 +87,7 @@ services:
- supermq-auth-db-volume:/var/lib/postgresql/data
auth-redis:
image: redis:7.2.4-alpine
image: docker.io/redis:8.2.2-alpine3.22
container_name: supermq-auth-redis
restart: on-failure
networks:
@@ -96,7 +96,7 @@ services:
- supermq-auth-redis-volume:/data
auth:
image: supermq/auth:${SMQ_RELEASE_TAG}
image: docker.io/supermq/auth:${SMQ_RELEASE_TAG}
container_name: supermq-auth
depends_on:
- auth-db
@@ -189,7 +189,7 @@ services:
create_host_path: true
domains-db:
image: postgres:16.2-alpine
image: docker.io/postgres:18.0-alpine3.22
container_name: supermq-domains-db
restart: on-failure
ports:
@@ -204,7 +204,7 @@ services:
- supermq-domains-db-volume:/var/lib/postgresql/data
domains-redis:
image: redis:7.2.4-alpine
image: docker.io/redis:8.2.2-alpine3.22
container_name: supermq-domains-redis
restart: on-failure
networks:
@@ -213,7 +213,7 @@ services:
- supermq-domains-redis-volume:/data
domains:
image: supermq/domains:${SMQ_RELEASE_TAG}
image: docker.io/supermq/domains:${SMQ_RELEASE_TAG}
container_name: supermq-domains
depends_on:
- domains-db
@@ -380,7 +380,7 @@ services:
create_host_path: true
nginx:
image: nginx:1.25.4-alpine
image: docker.io/nginx:1.29.2-alpine3.22
container_name: supermq-nginx
restart: on-failure
volumes:
@@ -423,7 +423,7 @@ services:
hard: 65536
clients-db:
image: postgres:16.2-alpine
image: docker.io/postgres:18.0-alpine3.22
container_name: supermq-clients-db
restart: on-failure
command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}"
@@ -440,7 +440,7 @@ services:
- supermq-clients-db-volume:/var/lib/postgresql/data
clients-redis:
image: redis:7.2.4-alpine
image: docker.io/redis:8.2.2-alpine3.22
container_name: supermq-clients-redis
restart: on-failure
networks:
@@ -449,7 +449,7 @@ services:
- supermq-clients-redis-volume:/data
clients:
image: supermq/clients:${SMQ_RELEASE_TAG}
image: docker.io/supermq/clients:${SMQ_RELEASE_TAG}
container_name: supermq-clients
depends_on:
- clients-db
@@ -616,7 +616,7 @@ services:
create_host_path: true
channels-db:
image: postgres:16.2-alpine
image: docker.io/postgres:18.0-alpine3.22
container_name: supermq-channels-db
restart: on-failure
command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}"
@@ -633,7 +633,7 @@ services:
- supermq-channels-db-volume:/var/lib/postgresql/data
channels-redis:
image: redis:7.2.4-alpine
image: docker.io/redis:8.2.2-alpine3.22
container_name: supermq-channels-redis
restart: on-failure
networks:
@@ -642,7 +642,7 @@ services:
- supermq-channels-redis-volume:/data
channels:
image: supermq/channels:${SMQ_RELEASE_TAG}
image: docker.io/supermq/channels:${SMQ_RELEASE_TAG}
container_name: supermq-channels
depends_on:
- channels-db
@@ -807,7 +807,7 @@ services:
create_host_path: true
users-db:
image: postgres:16.2-alpine
image: docker.io/postgres:18.0-alpine3.22
container_name: supermq-users-db
restart: on-failure
command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}"
@@ -824,7 +824,7 @@ services:
- supermq-users-db-volume:/var/lib/postgresql/data
users:
image: supermq/users:${SMQ_RELEASE_TAG}
image: docker.io/supermq/users:${SMQ_RELEASE_TAG}
container_name: supermq-users
depends_on:
- users-db
@@ -933,7 +933,7 @@ services:
create_host_path: true
groups-db:
image: postgres:16.2-alpine
image: docker.io/postgres:18.0-alpine3.22
container_name: supermq-groups-db
restart: on-failure
command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}"
@@ -950,7 +950,7 @@ services:
- supermq-groups-db-volume:/var/lib/postgresql/data
groups:
image: supermq/groups:${SMQ_RELEASE_TAG}
image: docker.io/supermq/groups:${SMQ_RELEASE_TAG}
container_name: supermq-groups
depends_on:
- groups-db
@@ -1113,7 +1113,7 @@ services:
create_host_path: true
jaeger:
image: jaegertracing/all-in-one:1.66.0
image: docker.io/jaegertracing/all-in-one:1.74.0
container_name: supermq-jaeger
environment:
COLLECTOR_OTLP_ENABLED: ${SMQ_JAEGER_COLLECTOR_OTLP_ENABLED}
@@ -1125,7 +1125,7 @@ services:
- supermq-base-net
mqtt-adapter:
image: supermq/mqtt:${SMQ_RELEASE_TAG}
image: docker.io/supermq/mqtt:${SMQ_RELEASE_TAG}
container_name: supermq-mqtt
depends_on:
- clients
@@ -1224,7 +1224,7 @@ services:
create_host_path: true
http-adapter:
image: supermq/http:${SMQ_RELEASE_TAG}
image: docker.io/supermq/http:${SMQ_RELEASE_TAG}
container_name: supermq-http
depends_on:
- clients
@@ -1336,7 +1336,7 @@ services:
create_host_path: true
coap-adapter:
image: supermq/coap:${SMQ_RELEASE_TAG}
image: docker.io/supermq/coap:${SMQ_RELEASE_TAG}
container_name: supermq-coap
depends_on:
- clients
@@ -1454,7 +1454,7 @@ services:
create_host_path: true
ws-adapter:
image: supermq/ws:${SMQ_RELEASE_TAG}
image: docker.io/supermq/ws:${SMQ_RELEASE_TAG}
container_name: supermq-ws
depends_on:
- clients
@@ -1566,7 +1566,7 @@ services:
create_host_path: true
rabbitmq:
image: rabbitmq:4.0.5-management-alpine
image: docker.io/rabbitmq:4.1.4-management-alpine
container_name: supermq-rabbitmq
restart: on-failure
environment:
@@ -1587,7 +1587,7 @@ services:
- supermq-base-net
nats:
image: nats:2.10.25-alpine
image: docker.io/nats:2.12.0-alpine3.22
container_name: supermq-nats
restart: on-failure
command: "--config=/etc/nats/nats.conf"
+1 -1
View File
@@ -134,7 +134,7 @@ func listRulesEndpoint(s re.Service) endpoint.Endpoint {
}
page, err := s.ListRules(ctx, session, req.PageMeta)
if err != nil {
return rulesPageRes{}, nil
return rulesPageRes{}, err
}
ret := rulesPageRes{
Page: page,
+163
View File
@@ -443,6 +443,14 @@ func TestListRulesEndpoint(t *testing.T) {
status: http.StatusBadRequest,
err: apiutil.ErrInvalidDirection,
},
{
desc: "list rules with invalid order",
domainID: domainID,
token: validToken,
query: "order=invalid",
status: http.StatusBadRequest,
err: apiutil.ErrInvalidOrder,
},
{
desc: "list rule with limit that is too big",
domainID: domainID,
@@ -507,6 +515,14 @@ func TestListRulesEndpoint(t *testing.T) {
status: http.StatusBadRequest,
err: apiutil.ErrInvalidQueryParams,
},
{
desc: "list rules with service error",
domainID: domainID,
token: validToken,
listRulesResponse: re.Page{},
status: http.StatusForbidden,
err: svcerr.ErrAuthorization,
},
}
for _, tc := range cases {
@@ -827,6 +843,153 @@ func TestUpdateRuleTagsEndpoint(t *testing.T) {
}
}
func TestUpdateRuleScheduleEndpoint(t *testing.T) {
ts, svc, authn := newRuleEngineServer()
defer ts.Close()
updateScheduleReq := pkgSch.Schedule{
StartDateTime: future,
Time: future.Add(2 * time.Hour),
Recurring: pkgSch.Weekly,
RecurringPeriod: 2,
}
ruleWithSchedule := rule
ruleWithSchedule.Schedule = updateScheduleReq
cases := []struct {
desc string
token string
id string
domainID string
schedule pkgSch.Schedule
contentType string
session smqauthn.Session
svcResp re.Rule
svcErr error
status int
authnErr error
err error
}{
{
desc: "update rule schedule successfully",
token: validToken,
domainID: domainID,
id: validID,
schedule: updateScheduleReq,
contentType: contentType,
svcResp: ruleWithSchedule,
status: http.StatusOK,
err: nil,
},
{
desc: "update rule schedule with invalid token",
token: invalidToken,
session: smqauthn.Session{},
domainID: domainID,
id: validID,
schedule: updateScheduleReq,
contentType: contentType,
authnErr: svcerr.ErrAuthentication,
status: http.StatusUnauthorized,
err: svcerr.ErrAuthentication,
},
{
desc: "update rule schedule with empty token",
token: "",
session: smqauthn.Session{},
domainID: domainID,
id: validID,
schedule: updateScheduleReq,
contentType: contentType,
status: http.StatusUnauthorized,
err: apiutil.ErrBearerToken,
},
{
desc: "update rule schedule with empty domainID",
token: validToken,
id: validID,
schedule: updateScheduleReq,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrMissingDomainID,
},
{
desc: "update rule schedule with invalid content type",
token: validToken,
id: validID,
domainID: domainID,
schedule: updateScheduleReq,
contentType: "application/xml",
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "update rule schedule with start_datetime in past",
token: validToken,
id: validID,
domainID: domainID,
schedule: pkgSch.Schedule{
StartDateTime: past,
Time: future,
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
},
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
},
{
desc: "update rule schedule with service error",
token: validToken,
id: validID,
domainID: domainID,
schedule: updateScheduleReq,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
status: http.StatusForbidden,
err: svcerr.ErrAuthorization,
},
{
desc: "update rule schedule with empty id",
token: validToken,
id: "",
domainID: domainID,
schedule: updateScheduleReq,
contentType: contentType,
status: http.StatusBadRequest,
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
data := toJSON(map[string]any{
"schedule": tc.schedule,
})
req := testRequest{
client: ts.Client(),
method: http.MethodPatch,
url: fmt.Sprintf("%s/%s/rules/%s/schedule", ts.URL, tc.domainID, tc.id),
contentType: tc.contentType,
token: tc.token,
body: strings.NewReader(data),
}
authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
svcCall := svc.On("UpdateRuleSchedule", mock.Anything, mock.Anything, mock.Anything).Return(tc.svcResp, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authCall.Unset()
})
}
}
func TestEnableRuleEndpoint(t *testing.T) {
ts, svc, authn := newRuleEngineServer()
defer ts.Close()
+1 -1
View File
@@ -56,7 +56,7 @@ func (req listRulesReq) validate() error {
switch req.Order {
case "", api.NameKey, api.CreatedAtOrder, api.UpdatedAtOrder:
default:
return apiutil.ErrInvalidOrder
return errors.Wrap(apiutil.ErrInvalidOrder, apiutil.ErrValidation)
}
if req.Dir != api.AscDir && req.Dir != api.DescDir {
+3 -3
View File
@@ -142,7 +142,7 @@ func (re *re) StartScheduler(ctx context.Context) error {
}
for _, r := range page.Rules {
go func(rule Rule) {
go func(rule Rule, dueTime time.Time) {
if _, err := re.repo.UpdateRuleDue(ctx, rule.ID, rule.Schedule.NextDue()); err != nil {
re.runInfo <- pkglog.RunInfo{Level: slog.LevelError, Message: fmt.Sprintf("failed to update rule: %s", err), Details: []slog.Attr{slog.Time("time", time.Now().UTC())}}
return
@@ -153,10 +153,10 @@ func (re *re) StartScheduler(ctx context.Context) error {
Channel: rule.InputChannel,
Subtopic: rule.InputTopic,
Protocol: protocol,
Created: due.Unix(),
Created: dueTime.Unix(),
}
re.runInfo <- re.process(ctx, rule, msg)
}(r)
}(r, due)
}
// Reset due, it will reset in the page meta as well.
due = time.Now().UTC()
+998
View File
@@ -0,0 +1,998 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres_test
import (
"context"
"fmt"
"sort"
"testing"
"time"
"github.com/0x6flab/namegenerator"
"github.com/absmach/magistrala/pkg/schedule"
"github.com/absmach/magistrala/re"
"github.com/absmach/magistrala/re/outputs"
"github.com/absmach/magistrala/re/postgres"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/absmach/supermq/pkg/uuid"
"github.com/stretchr/testify/assert"
)
const (
ascDir = "asc"
descDir = "desc"
nameOrder = "name"
createdAtOrder = "created_at"
updatedAtOrder = "updated_at"
)
var (
namegen = namegenerator.NewGenerator()
idProvider = uuid.New()
)
func TestAddRule(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
rule := re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: generateUUID(t),
Tags: []string{"test", "rule"},
InputChannel: generateUUID(t),
InputTopic: "temperature",
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Outputs: re.Outputs{
&outputs.Alarm{},
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
Metadata: map[string]any{
"key": "value",
},
}
scheduleName := namegen.Generate()
scheduleDomain := generateUUID(t)
scheduleChannel := generateUUID(t)
scheduleCreatedBy := generateUUID(t)
scheduleCreatedAt := time.Now().UTC().Truncate(time.Microsecond)
scheduleUpdatedBy := generateUUID(t)
scheduleUpdatedAt := time.Now().UTC().Truncate(time.Microsecond)
scheduleStartTime := time.Now().UTC().Add(time.Hour).Truncate(time.Microsecond)
scheduleTime := time.Now().UTC().Add(2 * time.Hour).Truncate(time.Microsecond)
scheduleRule := re.Rule{
ID: generateUUID(t),
Name: scheduleName,
DomainID: scheduleDomain,
InputChannel: scheduleChannel,
InputTopic: "humidity",
Logic: re.Script{
Type: re.LuaType,
Value: "return value > 50",
},
Schedule: schedule.Schedule{
StartDateTime: scheduleStartTime,
Time: scheduleTime,
Recurring: schedule.Daily,
RecurringPeriod: 1,
},
Status: re.EnabledStatus,
CreatedAt: scheduleCreatedAt,
CreatedBy: scheduleCreatedBy,
UpdatedAt: scheduleUpdatedAt,
UpdatedBy: scheduleUpdatedBy,
Metadata: re.Metadata{},
}
outputsName := namegen.Generate()
outputsDomain := generateUUID(t)
outputsChannel := generateUUID(t)
outputsCreatedBy := generateUUID(t)
outputsCreatedAt := time.Now().UTC().Truncate(time.Microsecond)
outputsUpdatedBy := generateUUID(t)
outputsUpdatedAt := time.Now().UTC().Truncate(time.Microsecond)
outputsRuleID := generateUUID(t)
outputsRule := re.Rule{
ID: outputsRuleID,
Name: outputsName,
DomainID: outputsDomain,
InputChannel: outputsChannel,
Logic: re.Script{
Type: re.GoType,
Value: "func() bool { return true }",
},
Outputs: re.Outputs{
&outputs.ChannelPublisher{
Channel: generateUUID(t),
Topic: "alerts",
},
&outputs.SenML{},
},
Status: re.EnabledStatus,
CreatedAt: outputsCreatedAt,
CreatedBy: outputsCreatedBy,
UpdatedAt: outputsUpdatedAt,
UpdatedBy: outputsUpdatedBy,
Metadata: re.Metadata{},
}
cases := []struct {
desc string
rule re.Rule
resp re.Rule
err error
}{
{
desc: "valid rule",
rule: rule,
resp: rule,
err: nil,
},
{
desc: "duplicate rule",
rule: rule,
resp: re.Rule{},
err: repoerr.ErrConflict,
},
{
desc: "rule with schedule",
rule: scheduleRule,
resp: scheduleRule,
err: nil,
},
{
desc: "rule with outputs",
rule: outputsRule,
resp: outputsRule,
err: nil,
},
{
desc: "invalid metadata",
rule: re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: generateUUID(t),
InputChannel: generateUUID(t),
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Metadata: map[string]any{
"key": make(chan int),
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
resp: re.Rule{},
err: repoerr.ErrMalformedEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
addedRule, err := repo.AddRule(context.Background(), tc.rule)
if err == nil {
tc.resp.ID = addedRule.ID
assert.Equal(t, tc.resp, addedRule, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, addedRule))
}
})
}
}
func TestViewRule(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
rule := re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: generateUUID(t),
InputChannel: generateUUID(t),
InputTopic: "temperature",
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
Metadata: map[string]any{
"key": "value",
},
}
rule, err := repo.AddRule(context.Background(), rule)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
cases := []struct {
desc string
id string
resp re.Rule
err error
}{
{
desc: "valid rule",
id: rule.ID,
resp: rule,
err: nil,
},
{
desc: "non existing rule",
id: generateUUID(t),
resp: re.Rule{},
err: repoerr.ErrViewEntity,
},
{
desc: "empty id",
id: "",
resp: re.Rule{},
err: repoerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
retrievedRule, err := repo.ViewRule(context.Background(), tc.id)
assert.Equal(t, tc.resp, retrievedRule, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, retrievedRule))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func TestUpdateRule(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
rule := re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: generateUUID(t),
InputChannel: generateUUID(t),
InputTopic: "temperature",
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
Metadata: map[string]any{
"key": "value",
},
}
rule, err := repo.AddRule(context.Background(), rule)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
newInputChannel := generateUUID(t)
newUpdatedBy := generateUUID(t)
cases := []struct {
desc string
rule re.Rule
resp re.Rule
err error
}{
{
desc: "valid rule update",
rule: re.Rule{
ID: rule.ID,
Name: "updated-name",
InputChannel: newInputChannel,
InputTopic: "humidity",
Logic: re.Script{
Type: re.LuaType,
Value: "return value > 30",
},
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: newUpdatedBy,
Metadata: map[string]any{
"updated": "metadata",
},
},
resp: re.Rule{
ID: rule.ID,
Name: "updated-name",
DomainID: rule.DomainID,
InputChannel: newInputChannel,
InputTopic: "humidity",
Logic: re.Script{
Type: re.LuaType,
Value: "return value > 30",
},
Status: rule.Status,
CreatedAt: rule.CreatedAt,
CreatedBy: rule.CreatedBy,
UpdatedAt: time.Time{},
UpdatedBy: newUpdatedBy,
Metadata: map[string]any{
"updated": "metadata",
},
},
err: nil,
},
{
desc: "update non-existing rule",
rule: re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
InputChannel: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
resp: re.Rule{},
err: repoerr.ErrNotFound,
},
{
desc: "update with invalid metadata",
rule: re.Rule{
ID: rule.ID,
InputChannel: generateUUID(t),
Metadata: map[string]any{
"key": make(chan int),
},
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
resp: re.Rule{},
err: repoerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
updatedRule, err := repo.UpdateRule(context.Background(), tc.rule)
if tc.err == nil {
tc.resp.UpdatedAt = updatedRule.UpdatedAt
}
assert.Equal(t, tc.resp, updatedRule, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, updatedRule))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func TestUpdateRuleStatus(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
rule := re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: generateUUID(t),
InputChannel: generateUUID(t),
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
}
rule, err := repo.AddRule(context.Background(), rule)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
cases := []struct {
desc string
rule re.Rule
status re.Status
err error
}{
{
desc: "disable rule",
rule: re.Rule{
ID: rule.ID,
Status: re.DisabledStatus,
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
status: re.DisabledStatus,
err: nil,
},
{
desc: "enable rule",
rule: re.Rule{
ID: rule.ID,
Status: re.EnabledStatus,
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
status: re.EnabledStatus,
err: nil,
},
{
desc: "update non-existing rule status",
rule: re.Rule{
ID: generateUUID(t),
Status: re.DisabledStatus,
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
updatedRule, err := repo.UpdateRuleStatus(context.Background(), tc.rule)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.rule.ID, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.ID, updatedRule.ID))
assert.Equal(t, tc.status, updatedRule.Status, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.status, updatedRule.Status))
assert.Equal(t, tc.rule.UpdatedBy, updatedRule.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.UpdatedBy, updatedRule.UpdatedBy))
}
})
}
}
func TestUpdateRuleTags(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
rule := re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: generateUUID(t),
InputChannel: generateUUID(t),
Tags: []string{"tag1", "tag2"},
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
}
rule, err := repo.AddRule(context.Background(), rule)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
cases := []struct {
desc string
rule re.Rule
tags []string
err error
}{
{
desc: "update tags",
rule: re.Rule{
ID: rule.ID,
Tags: []string{"newtag1", "newtag2", "newtag3"},
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
tags: []string{"newtag1", "newtag2", "newtag3"},
err: nil,
},
{
desc: "update non-existing rule tags",
rule: re.Rule{
ID: generateUUID(t),
Tags: []string{"tag"},
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
updatedRule, err := repo.UpdateRuleTags(context.Background(), tc.rule)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.rule.ID, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.ID, updatedRule.ID))
assert.Equal(t, tc.tags, updatedRule.Tags, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.tags, updatedRule.Tags))
assert.Equal(t, tc.rule.UpdatedBy, updatedRule.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.UpdatedBy, updatedRule.UpdatedBy))
}
})
}
}
func TestUpdateRuleSchedule(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
rule := re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: generateUUID(t),
InputChannel: generateUUID(t),
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
}
rule, err := repo.AddRule(context.Background(), rule)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
newSchedule := schedule.Schedule{
StartDateTime: time.Now().UTC().Add(time.Hour).Truncate(time.Microsecond),
Time: time.Now().UTC().Add(2 * time.Hour).Truncate(time.Microsecond),
Recurring: schedule.Weekly,
RecurringPeriod: 2,
}
cases := []struct {
desc string
rule re.Rule
schedule schedule.Schedule
err error
}{
{
desc: "update schedule",
rule: re.Rule{
ID: rule.ID,
Schedule: newSchedule,
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
schedule: newSchedule,
err: nil,
},
{
desc: "update non-existing rule schedule",
rule: re.Rule{
ID: generateUUID(t),
Schedule: newSchedule,
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
updatedRule, err := repo.UpdateRuleSchedule(context.Background(), tc.rule)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.rule.ID, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.ID, updatedRule.ID))
assert.Equal(t, tc.schedule.Recurring, updatedRule.Schedule.Recurring, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.schedule.Recurring, updatedRule.Schedule.Recurring))
assert.Equal(t, tc.schedule.RecurringPeriod, updatedRule.Schedule.RecurringPeriod, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.schedule.RecurringPeriod, updatedRule.Schedule.RecurringPeriod))
assert.Equal(t, tc.rule.UpdatedBy, updatedRule.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.UpdatedBy, updatedRule.UpdatedBy))
}
})
}
}
func TestUpdateRuleDue(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
rule := re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: generateUUID(t),
InputChannel: generateUUID(t),
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Schedule: schedule.Schedule{
Time: time.Now().UTC().Add(time.Hour).Truncate(time.Microsecond),
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
}
rule, err := repo.AddRule(context.Background(), rule)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
newDue := time.Now().UTC().Add(3 * time.Hour).Truncate(time.Microsecond)
cases := []struct {
desc string
id string
due time.Time
err error
}{
{
desc: "update due time",
id: rule.ID,
due: newDue,
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
updatedRule, err := repo.UpdateRuleDue(context.Background(), tc.id, tc.due)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.id, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.id, updatedRule.ID))
assert.True(t, updatedRule.Schedule.Time.Sub(tc.due) < time.Second, fmt.Sprintf("%s: expected due time close to %v got %v\n", tc.desc, tc.due, updatedRule.Schedule.Time))
}
})
}
}
func TestListRules(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := generateUUID(t)
channelID := generateUUID(t)
items := make([]re.Rule, 100)
for i := range 100 {
items[i] = re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: domainID,
InputChannel: channelID,
Tags: []string{fmt.Sprintf("tag%d", i%10)},
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
}
if i%2 == 0 {
items[i].Status = re.DisabledStatus
}
if i%3 == 0 {
items[i].Schedule = schedule.Schedule{
Time: time.Now().UTC().Add(time.Duration(i) * time.Hour),
Recurring: schedule.Daily,
}
}
rule, err := repo.AddRule(context.Background(), items[i])
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
items[i].ID = rule.ID
}
cases := []struct {
desc string
pm re.PageMeta
count int
err error
}{
{
desc: "list first page",
pm: re.PageMeta{
Offset: 0,
Limit: 10,
Status: re.AllStatus,
},
count: 10,
err: nil,
},
{
desc: "list with offset",
pm: re.PageMeta{
Offset: 10,
Limit: 20,
Status: re.AllStatus,
},
count: 20,
err: nil,
},
{
desc: "list by domain",
pm: re.PageMeta{
Domain: domainID,
Offset: 0,
Limit: 200,
Status: re.AllStatus,
},
count: 100,
err: nil,
},
{
desc: "list by channel",
pm: re.PageMeta{
InputChannel: channelID,
Offset: 0,
Limit: 200,
Status: re.AllStatus,
},
count: 100,
err: nil,
},
{
desc: "list enabled rules",
pm: re.PageMeta{
Status: re.EnabledStatus,
Offset: 0,
Limit: 200,
},
count: 50,
err: nil,
},
{
desc: "list disabled rules",
pm: re.PageMeta{
Status: re.DisabledStatus,
Offset: 0,
Limit: 200,
},
count: 50,
err: nil,
},
{
desc: "list by tag",
pm: re.PageMeta{
Tag: "tag1",
Offset: 0,
Limit: 200,
Status: re.AllStatus,
},
count: 10,
err: nil,
},
{
desc: "list with zero limit returns all",
pm: re.PageMeta{
Status: re.AllStatus,
},
count: 100,
err: nil,
},
{
desc: "list non-existing domain",
pm: re.PageMeta{
Domain: generateUUID(t),
Offset: 0,
Limit: 10,
Status: re.AllStatus,
},
count: 0,
err: nil,
},
{
desc: "list ordered by name ascending",
pm: re.PageMeta{
Offset: 0,
Limit: 10,
Status: re.AllStatus,
Order: nameOrder,
Dir: ascDir,
},
count: 10,
err: nil,
},
{
desc: "list ordered by name descending",
pm: re.PageMeta{
Offset: 0,
Limit: 10,
Status: re.AllStatus,
Order: nameOrder,
Dir: descDir,
},
count: 10,
err: nil,
},
{
desc: "list ordered by created_at ascending",
pm: re.PageMeta{
Offset: 0,
Limit: 10,
Status: re.AllStatus,
Order: createdAtOrder,
Dir: ascDir,
},
count: 10,
err: nil,
},
{
desc: "list ordered by created_at descending",
pm: re.PageMeta{
Offset: 0,
Limit: 10,
Status: re.AllStatus,
Order: createdAtOrder,
Dir: descDir,
},
count: 10,
err: nil,
},
{
desc: "list ordered by updated_at ascending",
pm: re.PageMeta{
Offset: 0,
Limit: 10,
Status: re.AllStatus,
Order: updatedAtOrder,
Dir: ascDir,
},
count: 10,
err: nil,
},
{
desc: "list ordered by updated_at descending",
pm: re.PageMeta{
Offset: 0,
Limit: 10,
Status: re.AllStatus,
Order: updatedAtOrder,
Dir: descDir,
},
count: 10,
err: nil,
},
{
desc: "list with default order (updated_at desc)",
pm: re.PageMeta{
Offset: 0,
Limit: 10,
Status: re.AllStatus,
},
count: 10,
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
page, err := repo.ListRules(context.Background(), tc.pm)
if tc.err != nil {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
return
}
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
assert.Equal(t, tc.count, len(page.Rules), fmt.Sprintf("%s: expected %d rules, got %d", tc.desc, tc.count, len(page.Rules)))
if len(page.Rules) > 1 {
switch tc.pm.Order {
case nameOrder:
if tc.pm.Dir == ascDir {
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
return page.Rules[i].Name <= page.Rules[j].Name
}), "Expected names to be sorted ascending")
} else {
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
return page.Rules[i].Name >= page.Rules[j].Name
}), "Expected names to be sorted descending")
}
case createdAtOrder:
if tc.pm.Dir == ascDir {
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
return page.Rules[i].CreatedAt.Before(page.Rules[j].CreatedAt)
}), "Expected created_at to be sorted ascending")
} else {
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
return page.Rules[i].CreatedAt.After(page.Rules[j].CreatedAt)
}), "Expected created_at to be sorted descending")
}
case updatedAtOrder:
if tc.pm.Dir == ascDir {
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
return page.Rules[i].UpdatedAt.Before(page.Rules[j].UpdatedAt)
}), "Expected updated_at to be sorted ascending")
} else {
assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool {
return page.Rules[i].UpdatedAt.After(page.Rules[j].UpdatedAt)
}), "Expected updated_at to be sorted descending")
}
}
}
})
}
}
func TestRemoveRule(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM rules")
assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
rule := re.Rule{
ID: generateUUID(t),
Name: namegen.Generate(),
DomainID: generateUUID(t),
InputChannel: generateUUID(t),
Logic: re.Script{
Type: re.LuaType,
Value: "return true",
},
Status: re.EnabledStatus,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
CreatedBy: generateUUID(t),
UpdatedAt: time.Now().UTC().Truncate(time.Microsecond),
UpdatedBy: generateUUID(t),
}
rule, err := repo.AddRule(context.Background(), rule)
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
cases := []struct {
desc string
id string
err error
}{
{
desc: "remove existing rule",
id: rule.ID,
err: nil,
},
{
desc: "remove non-existing rule",
id: generateUUID(t),
err: repoerr.ErrNotFound,
},
{
desc: "remove already removed rule",
id: rule.ID,
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := repo.RemoveRule(context.Background(), tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func generateUUID(t *testing.T) string {
ulid, err := idProvider.ID()
assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
return ulid
}
+93
View File
@@ -0,0 +1,93 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres_test
import (
"database/sql"
"fmt"
"log"
"os"
"testing"
"time"
repostgres "github.com/absmach/magistrala/re/postgres"
"github.com/absmach/supermq/pkg/postgres"
"github.com/jmoiron/sqlx"
dockertest "github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"go.opentelemetry.io/otel"
)
var (
db *sqlx.DB
database postgres.Database
tracer = otel.Tracer("repo_tests")
)
func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "postgres",
Tag: "16.2-alpine",
Env: []string{
"POSTGRES_USER=test",
"POSTGRES_PASSWORD=test",
"POSTGRES_DB=test",
"listen_addresses = '*'",
},
}, func(config *docker.HostConfig) {
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
port := container.GetPort("5432/tcp")
// exponential backoff-retry, because the application in the container might not be ready to accept connections yet
pool.MaxWait = 120 * time.Second
if err := pool.Retry(func() error {
url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port)
db, err := sql.Open("pgx", url)
if err != nil {
return err
}
return db.Ping()
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
dbConfig := postgres.Config{
Host: "localhost",
Port: port,
User: "test",
Pass: "test",
Name: "test",
SSLMode: "disable",
SSLCert: "",
SSLKey: "",
SSLRootCert: "",
}
if db, err = postgres.Setup(dbConfig, *repostgres.Migration()); err != nil {
log.Fatalf("Could not setup test DB connection: %s", err)
}
database = postgres.NewDatabase(db, dbConfig, tracer)
code := m.Run()
// Defers will not be run when using os.Exit
db.Close()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
}
+887 -15
View File
File diff suppressed because it is too large Load Diff
+205
View File
@@ -0,0 +1,205 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package re_test
import (
"encoding/json"
"testing"
"github.com/absmach/magistrala/re"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestToStatus(t *testing.T) {
cases := []struct {
desc string
status string
res re.Status
err error
}{
{
desc: "convert enabled status",
status: re.Enabled,
res: re.EnabledStatus,
err: nil,
},
{
desc: "convert empty string to enabled status",
status: "",
res: re.EnabledStatus,
err: nil,
},
{
desc: "convert disabled status",
status: re.Disabled,
res: re.DisabledStatus,
err: nil,
},
{
desc: "convert deleted status",
status: re.Deleted,
res: re.DeletedStatus,
err: nil,
},
{
desc: "convert all status",
status: re.All,
res: re.AllStatus,
err: nil,
},
{
desc: "convert invalid status",
status: "invalid",
res: re.Status(0),
err: svcerr.ErrInvalidStatus,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
status, err := re.ToStatus(tc.status)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.res, status)
})
}
}
func TestStatusString(t *testing.T) {
cases := []struct {
desc string
status re.Status
res string
}{
{
desc: "enabled status to string",
status: re.EnabledStatus,
res: re.Enabled,
},
{
desc: "disabled status to string",
status: re.DisabledStatus,
res: re.Disabled,
},
{
desc: "deleted status to string",
status: re.DeletedStatus,
res: re.Deleted,
},
{
desc: "all status to string",
status: re.AllStatus,
res: re.All,
},
{
desc: "unknown status to string",
status: re.Status(99),
res: re.Unknown,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
assert.Equal(t, tc.res, tc.status.String())
})
}
}
func TestStatusMarshalJSON(t *testing.T) {
cases := []struct {
desc string
status re.Status
res string
}{
{
desc: "marshal enabled status",
status: re.EnabledStatus,
res: `"enabled"`,
},
{
desc: "marshal disabled status",
status: re.DisabledStatus,
res: `"disabled"`,
},
{
desc: "marshal deleted status",
status: re.DeletedStatus,
res: `"deleted"`,
},
{
desc: "marshal all status",
status: re.AllStatus,
res: `"all"`,
},
{
desc: "marshal unknown status",
status: re.Status(99),
res: `"unknown"`,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
data, err := json.Marshal(tc.status)
require.NoError(t, err)
assert.Equal(t, tc.res, string(data))
})
}
}
func TestStatusUnmarshalJSON(t *testing.T) {
cases := []struct {
desc string
data string
res re.Status
err error
}{
{
desc: "unmarshal enabled status",
data: `"enabled"`,
res: re.EnabledStatus,
err: nil,
},
{
desc: "unmarshal disabled status",
data: `"disabled"`,
res: re.DisabledStatus,
err: nil,
},
{
desc: "unmarshal deleted status",
data: `"deleted"`,
res: re.DeletedStatus,
err: nil,
},
{
desc: "unmarshal all status",
data: `"all"`,
res: re.AllStatus,
err: nil,
},
{
desc: "unmarshal empty string to enabled status",
data: `""`,
res: re.EnabledStatus,
err: nil,
},
{
desc: "unmarshal invalid status",
data: `"invalid"`,
res: re.Status(0),
err: svcerr.ErrInvalidStatus,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
var status re.Status
err := json.Unmarshal([]byte(tc.data), &status)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.res, status)
})
}
}