mirror of
https://github.com/absmach/supermq.git
synced 2026-06-23 07:00:25 +00:00
MG-225 - Fix schedule validation (#222)
* initial implementation Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix failing linter Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix tests Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix failing tests Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * add utc for reports Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * address comments Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fic wrapper Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * remove unused code Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * remove auth error Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> * fix tests Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> --------- Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
+42
-23
@@ -5,11 +5,15 @@ package schedule
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
)
|
||||
|
||||
var ErrInvalidRecurringType = errors.New("invalid recurring type")
|
||||
var (
|
||||
ErrInvalidRecurringType = errors.New("invalid recurring type")
|
||||
ErrStartDateTimeInPast = errors.New("start_datetime must be greater than or equal to current time")
|
||||
)
|
||||
|
||||
// Type can be daily, weekly or monthly.
|
||||
type Recurring uint
|
||||
@@ -60,52 +64,67 @@ func (rt *Recurring) UnmarshalJSON(data []byte) error {
|
||||
}
|
||||
|
||||
type Schedule struct {
|
||||
StartDateTime time.Time `json:"start_datetime"` // When the schedule becomes active
|
||||
Time time.Time `json:"time"` // Specific time for the rule to run
|
||||
Recurring Recurring `json:"recurring"` // None, Daily, Weekly, Monthly
|
||||
RecurringPeriod uint `json:"recurring_period"` // Controls how many intervals to skip between executions: 1 = every interval, 2 = every second interval, etc.
|
||||
StartDateTime *time.Time `json:"start_datetime"` // When the schedule becomes active
|
||||
Time time.Time `json:"time"` // Specific time for the rule to run
|
||||
Recurring Recurring `json:"recurring"` // None, Daily, Weekly, Monthly
|
||||
RecurringPeriod uint `json:"recurring_period"` // Controls how many intervals to skip between executions: 1 = every interval, 2 = every second interval, etc.
|
||||
}
|
||||
|
||||
func (s Schedule) Validate() error {
|
||||
if s.StartDateTime != nil {
|
||||
now := time.Now().UTC()
|
||||
if s.StartDateTime.Before(now) {
|
||||
return ErrStartDateTimeInPast
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s Schedule) MarshalJSON() ([]byte, error) {
|
||||
type Alias Schedule
|
||||
jTimes := struct {
|
||||
StartDateTime string `json:"start_datetime"`
|
||||
Time string `json:"time"`
|
||||
StartDateTime *string `json:"start_datetime"`
|
||||
Time string `json:"time"`
|
||||
*Alias
|
||||
}{
|
||||
StartDateTime: s.StartDateTime.Format(time.RFC3339),
|
||||
Time: s.Time.Format(time.RFC3339),
|
||||
Alias: (*Alias)(&s),
|
||||
Time: s.Time.Format(time.RFC3339),
|
||||
Alias: (*Alias)(&s),
|
||||
}
|
||||
if s.StartDateTime != nil {
|
||||
formatted := s.StartDateTime.Format(time.RFC3339)
|
||||
jTimes.StartDateTime = &formatted
|
||||
}
|
||||
|
||||
return json.Marshal(jTimes)
|
||||
}
|
||||
|
||||
func (s *Schedule) UnmarshalJSON(data []byte) error {
|
||||
type Alias Schedule
|
||||
aux := struct {
|
||||
temp := struct {
|
||||
StartDateTime string `json:"start_datetime"`
|
||||
Time string `json:"time"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(s),
|
||||
}
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startDateTime, err := time.Parse(time.RFC3339, aux.StartDateTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.StartDateTime = startDateTime
|
||||
|
||||
if aux.Time != "" {
|
||||
time, err := time.Parse(time.RFC3339, aux.Time)
|
||||
s.StartDateTime = nil
|
||||
if temp.StartDateTime != "" {
|
||||
startDateTime, err := time.Parse(time.RFC3339, temp.StartDateTime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Time = time
|
||||
s.StartDateTime = &startDateTime
|
||||
}
|
||||
if temp.Time != "" {
|
||||
parsedTime, err := time.Parse(time.RFC3339, temp.Time)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Time = parsedTime
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+12
-12
@@ -41,8 +41,9 @@ var (
|
||||
validToken = "valid"
|
||||
invalidToken = "invalid"
|
||||
now = time.Now().UTC().Truncate(time.Minute)
|
||||
future = now.Add(1 * time.Hour)
|
||||
schedule = pkgSch.Schedule{
|
||||
StartDateTime: now.Add(1 * time.Hour),
|
||||
StartDateTime: &future,
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
@@ -56,6 +57,13 @@ var (
|
||||
"name": "test",
|
||||
},
|
||||
}
|
||||
past = now.Add(-1 * time.Hour)
|
||||
scheduleInPast = pkgSch.Schedule{
|
||||
StartDateTime: &past,
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: past,
|
||||
}
|
||||
)
|
||||
|
||||
type testRequest struct {
|
||||
@@ -109,13 +117,6 @@ func TestAddRuleEndpoint(t *testing.T) {
|
||||
ts, svc, authn := newRuleEngineServer()
|
||||
defer ts.Close()
|
||||
|
||||
scheduleInPast := pkgSch.Schedule{
|
||||
StartDateTime: now.Add(-1 * time.Hour),
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
}
|
||||
|
||||
ruleInPast := rule
|
||||
ruleInPast.Schedule = scheduleInPast
|
||||
|
||||
@@ -204,7 +205,6 @@ func TestAddRuleEndpoint(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
|
||||
rule: ruleInPast,
|
||||
contentType: contentType,
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
},
|
||||
@@ -215,9 +215,9 @@ func TestAddRuleEndpoint(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
|
||||
rule: rule,
|
||||
contentType: contentType,
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusForbidden,
|
||||
err: svcerr.ErrAuthorization,
|
||||
svcErr: svcerr.ErrCreateEntity,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
+4
-10
@@ -4,8 +4,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/schedule"
|
||||
"github.com/absmach/magistrala/re"
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
@@ -19,8 +17,6 @@ const (
|
||||
MaxTitleSize = 37
|
||||
)
|
||||
|
||||
var ErrStartDateTimeInPast = errors.New("start_datetime must be greater than or equal to current time")
|
||||
|
||||
type addRuleReq struct {
|
||||
re.Rule
|
||||
}
|
||||
@@ -29,9 +25,8 @@ func (req addRuleReq) validate() error {
|
||||
if len(req.Name) > api.MaxNameSize || req.Name == "" {
|
||||
return apiutil.ErrNameSize
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if req.Schedule.StartDateTime.Before(now) {
|
||||
return errors.Wrap(ErrStartDateTimeInPast, apiutil.ErrValidation)
|
||||
if err := req.Rule.Schedule.Validate(); err != nil {
|
||||
return errors.Wrap(err, apiutil.ErrValidation)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -101,9 +96,8 @@ func (req updateRuleScheduleReq) validate() error {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if req.Schedule.StartDateTime.Before(now) {
|
||||
return errors.Wrap(ErrStartDateTimeInPast, apiutil.ErrValidation)
|
||||
if err := req.Schedule.Validate(); err != nil {
|
||||
return errors.Wrap(err, apiutil.ErrValidation)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
+4
-3
@@ -53,8 +53,9 @@ func ruleToDb(r re.Rule) (dbRule, error) {
|
||||
for _, v := range r.Logic.Outputs {
|
||||
lo = append(lo, int32(v))
|
||||
}
|
||||
start := sql.NullTime{Time: r.Schedule.StartDateTime}
|
||||
if !r.Schedule.StartDateTime.IsZero() {
|
||||
start := sql.NullTime{}
|
||||
if r.Schedule.StartDateTime != nil && !r.Schedule.StartDateTime.IsZero() {
|
||||
start.Time = *r.Schedule.StartDateTime
|
||||
start.Valid = true
|
||||
}
|
||||
t := sql.NullTime{Time: r.Schedule.Time}
|
||||
@@ -121,7 +122,7 @@ func dbToRule(dto dbRule) (re.Rule, error) {
|
||||
OutputChannel: fromNullString(dto.OutputChannel),
|
||||
OutputTopic: fromNullString(dto.OutputTopic),
|
||||
Schedule: schedule.Schedule{
|
||||
StartDateTime: dto.StartDateTime.Time,
|
||||
StartDateTime: &dto.StartDateTime.Time,
|
||||
Time: dto.Time.Time,
|
||||
Recurring: dto.Recurring,
|
||||
RecurringPeriod: dto.RecurringPeriod,
|
||||
|
||||
+4
-4
@@ -56,10 +56,10 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule,
|
||||
r.DomainID = session.DomainID
|
||||
r.Status = EnabledStatus
|
||||
|
||||
if r.Schedule.StartDateTime.IsZero() {
|
||||
r.Schedule.StartDateTime = now
|
||||
if r.Schedule.StartDateTime == nil || r.Schedule.StartDateTime.IsZero() {
|
||||
r.Schedule.StartDateTime = &now
|
||||
}
|
||||
r.Schedule.Time = r.Schedule.StartDateTime
|
||||
r.Schedule.Time = *r.Schedule.StartDateTime
|
||||
|
||||
rule, err := re.repo.AddRule(ctx, r)
|
||||
if err != nil {
|
||||
@@ -103,7 +103,7 @@ func (re *re) UpdateRuleTags(ctx context.Context, session authn.Session, r Rule)
|
||||
func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r Rule) (Rule, error) {
|
||||
r.UpdatedAt = time.Now().UTC()
|
||||
r.UpdatedBy = session.UserID
|
||||
r.Schedule.Time = r.Schedule.StartDateTime
|
||||
r.Schedule.Time = *r.Schedule.StartDateTime
|
||||
rule, err := re.repo.UpdateRuleSchedule(ctx, r)
|
||||
if err != nil {
|
||||
return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
|
||||
+11
-10
@@ -28,15 +28,16 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
namegen = namegenerator.NewGenerator()
|
||||
userID = testsutil.GenerateUUID(&testing.T{})
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
ruleName = namegen.Generate()
|
||||
ruleID = testsutil.GenerateUUID(&testing.T{})
|
||||
Tags = []string{"tag1", "tag2"}
|
||||
inputChannel = "test.channel"
|
||||
schedule = pkgSch.Schedule{
|
||||
StartDateTime: time.Now().Add(-time.Hour),
|
||||
namegen = namegenerator.NewGenerator()
|
||||
userID = testsutil.GenerateUUID(&testing.T{})
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
ruleName = namegen.Generate()
|
||||
ruleID = testsutil.GenerateUUID(&testing.T{})
|
||||
Tags = []string{"tag1", "tag2"}
|
||||
inputChannel = "test.channel"
|
||||
StartDateTime = time.Now().Add(-time.Hour)
|
||||
schedule = pkgSch.Schedule{
|
||||
StartDateTime: &StartDateTime,
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: time.Now().Add(-time.Hour),
|
||||
@@ -344,7 +345,7 @@ func TestListRules(t *testing.T) {
|
||||
Recurring: pkgSch.Daily,
|
||||
Time: now.Add(1 * time.Hour),
|
||||
RecurringPeriod: 1,
|
||||
StartDateTime: now.Add(-1 * time.Hour),
|
||||
StartDateTime: &now,
|
||||
},
|
||||
}
|
||||
rules = append(rules, r)
|
||||
|
||||
@@ -41,11 +41,12 @@ var (
|
||||
validToken = "valid"
|
||||
invalidToken = "invalid"
|
||||
now = time.Now().UTC().Truncate(time.Minute)
|
||||
future = now.Add(1 * time.Hour)
|
||||
schedule = pkgSch.Schedule{
|
||||
StartDateTime: now.Add(1 * time.Hour),
|
||||
StartDateTime: &future,
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
Time: future,
|
||||
}
|
||||
reportConfig = reports.ReportConfig{
|
||||
ID: validID,
|
||||
@@ -126,7 +127,7 @@ func TestAddReportConfigEndpoint(t *testing.T) {
|
||||
defer ts.Close()
|
||||
|
||||
scheduleInPast := pkgSch.Schedule{
|
||||
StartDateTime: now.Add(-1 * time.Hour),
|
||||
StartDateTime: &now,
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: now,
|
||||
@@ -203,7 +204,6 @@ func TestAddReportConfigEndpoint(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
|
||||
cfg: reportInPast,
|
||||
contentType: contentType,
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
},
|
||||
@@ -214,9 +214,9 @@ func TestAddReportConfigEndpoint(t *testing.T) {
|
||||
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
|
||||
cfg: reportConfig,
|
||||
contentType: contentType,
|
||||
svcErr: svcerr.ErrAuthorization,
|
||||
status: http.StatusForbidden,
|
||||
err: svcerr.ErrAuthorization,
|
||||
svcErr: svcerr.ErrCreateEntity,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/pkg/schedule"
|
||||
"github.com/absmach/magistrala/reports"
|
||||
@@ -22,8 +21,6 @@ const (
|
||||
errInvalidMetric = "invalid metric[%d]: %w"
|
||||
)
|
||||
|
||||
var ErrStartDateTimeInPast = errors.New("start_datetime must be greater than or equal to current time")
|
||||
|
||||
var (
|
||||
errInvalidReportAction = errors.New("invalid report action")
|
||||
errMetricsNotProvided = errors.New("metrics not provided")
|
||||
@@ -41,9 +38,8 @@ func (req addReportConfigReq) validate() error {
|
||||
if req.Name == "" {
|
||||
return apiutil.ErrMissingName
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if req.Schedule.StartDateTime.Before(now) {
|
||||
return errors.Wrap(apiutil.ErrValidation, ErrStartDateTimeInPast)
|
||||
if err := req.Schedule.Validate(); err != nil {
|
||||
return errors.Wrap(err, apiutil.ErrValidation)
|
||||
}
|
||||
return validateReportConfig(req.ReportConfig, false, false)
|
||||
}
|
||||
@@ -91,9 +87,8 @@ func (req updateReportScheduleReq) validate() error {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
if req.Schedule.StartDateTime.Before(now) {
|
||||
return errors.Wrap(apiutil.ErrValidation, ErrStartDateTimeInPast)
|
||||
if err := req.Schedule.Validate(); err != nil {
|
||||
return errors.Wrap(err, apiutil.ErrValidation)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -60,9 +60,9 @@ func reportToDb(r reports.ReportConfig) (dbReport, error) {
|
||||
}
|
||||
email = e
|
||||
}
|
||||
|
||||
start := sql.NullTime{Time: r.Schedule.StartDateTime}
|
||||
if !r.Schedule.StartDateTime.IsZero() {
|
||||
start := sql.NullTime{}
|
||||
if r.Schedule.StartDateTime != nil && !r.Schedule.StartDateTime.IsZero() {
|
||||
start.Time = *r.Schedule.StartDateTime
|
||||
start.Valid = true
|
||||
}
|
||||
t := sql.NullTime{Time: r.Schedule.Time}
|
||||
@@ -120,7 +120,7 @@ func dbToReport(dto dbReport) (reports.ReportConfig, error) {
|
||||
Config: &config,
|
||||
Metrics: metrics,
|
||||
Schedule: schedule.Schedule{
|
||||
StartDateTime: dto.StartDateTime.Time,
|
||||
StartDateTime: &dto.StartDateTime.Time,
|
||||
Time: dto.Due.Time,
|
||||
Recurring: dto.Recurring,
|
||||
RecurringPeriod: dto.RecurringPeriod,
|
||||
|
||||
+2
-2
@@ -122,7 +122,7 @@ type Metric struct {
|
||||
Name string `json:"name,omitempty"` // Mandatory field
|
||||
Subtopic string `json:"subtopic,omitempty"` // Optional field
|
||||
Protocol string `json:"protocol,omitempty"` // Optional field
|
||||
Format string `json:"format,omitiempty"` // Optional field
|
||||
Format string `json:"format,omitempty"` // Optional field
|
||||
}
|
||||
|
||||
type ReqMetric struct {
|
||||
@@ -131,7 +131,7 @@ type ReqMetric struct {
|
||||
Name string `json:"name,omitempty"` // Mandatory field
|
||||
Subtopic string `json:"subtopic,omitempty"` // Optional field
|
||||
Protocol string `json:"protocol,omitempty"` // Optional field
|
||||
Format string `json:"format,omitiempty"` // Optional field
|
||||
Format string `json:"format,omitempty"` // Optional field
|
||||
}
|
||||
|
||||
func (rm ReqMetric) Validate() error {
|
||||
|
||||
+5
-5
@@ -49,17 +49,17 @@ func (r *report) AddReportConfig(ctx context.Context, session authn.Session, cfg
|
||||
return ReportConfig{}, err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
now := time.Now().UTC()
|
||||
cfg.ID = id
|
||||
cfg.CreatedAt = now
|
||||
cfg.CreatedBy = session.UserID
|
||||
cfg.DomainID = session.DomainID
|
||||
cfg.Status = EnabledStatus
|
||||
|
||||
if cfg.Schedule.StartDateTime.IsZero() {
|
||||
cfg.Schedule.StartDateTime = now
|
||||
if cfg.Schedule.StartDateTime == nil || cfg.Schedule.StartDateTime.IsZero() {
|
||||
cfg.Schedule.StartDateTime = &now
|
||||
}
|
||||
cfg.Schedule.Time = cfg.Schedule.StartDateTime
|
||||
cfg.Schedule.Time = *cfg.Schedule.StartDateTime
|
||||
|
||||
reportConfig, err := r.repo.AddReportConfig(ctx, cfg)
|
||||
if err != nil {
|
||||
@@ -92,7 +92,7 @@ func (r *report) UpdateReportConfig(ctx context.Context, session authn.Session,
|
||||
func (r *report) UpdateReportSchedule(ctx context.Context, session authn.Session, cfg ReportConfig) (ReportConfig, error) {
|
||||
cfg.UpdatedAt = time.Now().UTC()
|
||||
cfg.UpdatedBy = session.UserID
|
||||
cfg.Schedule.Time = cfg.Schedule.StartDateTime
|
||||
cfg.Schedule.Time = *cfg.Schedule.StartDateTime
|
||||
c, err := r.repo.UpdateReportSchedule(ctx, cfg)
|
||||
if err != nil {
|
||||
return ReportConfig{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
|
||||
@@ -30,8 +30,9 @@ var (
|
||||
namegen = namegenerator.NewGenerator()
|
||||
userID = testsutil.GenerateUUID(&testing.T{})
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
now = time.Now().UTC()
|
||||
schedule = pkgSch.Schedule{
|
||||
StartDateTime: time.Now().Add(-time.Hour),
|
||||
StartDateTime: &now,
|
||||
Recurring: pkgSch.Daily,
|
||||
RecurringPeriod: 1,
|
||||
Time: time.Now().Add(-time.Hour),
|
||||
|
||||
Reference in New Issue
Block a user