MG-225 - Fix schedule validation (#222)

* initial implementation

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix failing linter

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix tests

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix failing tests

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* add utc for reports

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* address comments

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fic wrapper

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* remove unused code

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* remove auth error

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

* fix tests

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>

---------

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Steve Munene
2025-06-24 15:45:21 +03:00
committed by GitHub
parent 2c9c594100
commit 88d2ef3257
12 changed files with 101 additions and 90 deletions
+42 -23
View File
@@ -5,11 +5,15 @@ package schedule
import (
"encoding/json"
"errors"
"time"
"github.com/absmach/supermq/pkg/errors"
)
var ErrInvalidRecurringType = errors.New("invalid recurring type")
var (
ErrInvalidRecurringType = errors.New("invalid recurring type")
ErrStartDateTimeInPast = errors.New("start_datetime must be greater than or equal to current time")
)
// Type can be daily, weekly or monthly.
type Recurring uint
@@ -60,52 +64,67 @@ func (rt *Recurring) UnmarshalJSON(data []byte) error {
}
type Schedule struct {
StartDateTime time.Time `json:"start_datetime"` // When the schedule becomes active
Time time.Time `json:"time"` // Specific time for the rule to run
Recurring Recurring `json:"recurring"` // None, Daily, Weekly, Monthly
RecurringPeriod uint `json:"recurring_period"` // Controls how many intervals to skip between executions: 1 = every interval, 2 = every second interval, etc.
StartDateTime *time.Time `json:"start_datetime"` // When the schedule becomes active
Time time.Time `json:"time"` // Specific time for the rule to run
Recurring Recurring `json:"recurring"` // None, Daily, Weekly, Monthly
RecurringPeriod uint `json:"recurring_period"` // Controls how many intervals to skip between executions: 1 = every interval, 2 = every second interval, etc.
}
func (s Schedule) Validate() error {
if s.StartDateTime != nil {
now := time.Now().UTC()
if s.StartDateTime.Before(now) {
return ErrStartDateTimeInPast
}
}
return nil
}
func (s Schedule) MarshalJSON() ([]byte, error) {
type Alias Schedule
jTimes := struct {
StartDateTime string `json:"start_datetime"`
Time string `json:"time"`
StartDateTime *string `json:"start_datetime"`
Time string `json:"time"`
*Alias
}{
StartDateTime: s.StartDateTime.Format(time.RFC3339),
Time: s.Time.Format(time.RFC3339),
Alias: (*Alias)(&s),
Time: s.Time.Format(time.RFC3339),
Alias: (*Alias)(&s),
}
if s.StartDateTime != nil {
formatted := s.StartDateTime.Format(time.RFC3339)
jTimes.StartDateTime = &formatted
}
return json.Marshal(jTimes)
}
func (s *Schedule) UnmarshalJSON(data []byte) error {
type Alias Schedule
aux := struct {
temp := struct {
StartDateTime string `json:"start_datetime"`
Time string `json:"time"`
*Alias
}{
Alias: (*Alias)(s),
}
if err := json.Unmarshal(data, &aux); err != nil {
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
startDateTime, err := time.Parse(time.RFC3339, aux.StartDateTime)
if err != nil {
return err
}
s.StartDateTime = startDateTime
if aux.Time != "" {
time, err := time.Parse(time.RFC3339, aux.Time)
s.StartDateTime = nil
if temp.StartDateTime != "" {
startDateTime, err := time.Parse(time.RFC3339, temp.StartDateTime)
if err != nil {
return err
}
s.Time = time
s.StartDateTime = &startDateTime
}
if temp.Time != "" {
parsedTime, err := time.Parse(time.RFC3339, temp.Time)
if err != nil {
return err
}
s.Time = parsedTime
}
return nil
}
+12 -12
View File
@@ -41,8 +41,9 @@ var (
validToken = "valid"
invalidToken = "invalid"
now = time.Now().UTC().Truncate(time.Minute)
future = now.Add(1 * time.Hour)
schedule = pkgSch.Schedule{
StartDateTime: now.Add(1 * time.Hour),
StartDateTime: &future,
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
@@ -56,6 +57,13 @@ var (
"name": "test",
},
}
past = now.Add(-1 * time.Hour)
scheduleInPast = pkgSch.Schedule{
StartDateTime: &past,
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: past,
}
)
type testRequest struct {
@@ -109,13 +117,6 @@ func TestAddRuleEndpoint(t *testing.T) {
ts, svc, authn := newRuleEngineServer()
defer ts.Close()
scheduleInPast := pkgSch.Schedule{
StartDateTime: now.Add(-1 * time.Hour),
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
}
ruleInPast := rule
ruleInPast.Schedule = scheduleInPast
@@ -204,7 +205,6 @@ func TestAddRuleEndpoint(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
rule: ruleInPast,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
},
@@ -215,9 +215,9 @@ func TestAddRuleEndpoint(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
rule: rule,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
status: http.StatusForbidden,
err: svcerr.ErrAuthorization,
svcErr: svcerr.ErrCreateEntity,
status: http.StatusUnprocessableEntity,
err: svcerr.ErrCreateEntity,
},
}
+4 -10
View File
@@ -4,8 +4,6 @@
package api
import (
"time"
"github.com/absmach/magistrala/pkg/schedule"
"github.com/absmach/magistrala/re"
api "github.com/absmach/supermq/api/http"
@@ -19,8 +17,6 @@ const (
MaxTitleSize = 37
)
var ErrStartDateTimeInPast = errors.New("start_datetime must be greater than or equal to current time")
type addRuleReq struct {
re.Rule
}
@@ -29,9 +25,8 @@ func (req addRuleReq) validate() error {
if len(req.Name) > api.MaxNameSize || req.Name == "" {
return apiutil.ErrNameSize
}
now := time.Now().UTC()
if req.Schedule.StartDateTime.Before(now) {
return errors.Wrap(ErrStartDateTimeInPast, apiutil.ErrValidation)
if err := req.Rule.Schedule.Validate(); err != nil {
return errors.Wrap(err, apiutil.ErrValidation)
}
return nil
}
@@ -101,9 +96,8 @@ func (req updateRuleScheduleReq) validate() error {
return apiutil.ErrMissingID
}
now := time.Now().UTC()
if req.Schedule.StartDateTime.Before(now) {
return errors.Wrap(ErrStartDateTimeInPast, apiutil.ErrValidation)
if err := req.Schedule.Validate(); err != nil {
return errors.Wrap(err, apiutil.ErrValidation)
}
return nil
+4 -3
View File
@@ -53,8 +53,9 @@ func ruleToDb(r re.Rule) (dbRule, error) {
for _, v := range r.Logic.Outputs {
lo = append(lo, int32(v))
}
start := sql.NullTime{Time: r.Schedule.StartDateTime}
if !r.Schedule.StartDateTime.IsZero() {
start := sql.NullTime{}
if r.Schedule.StartDateTime != nil && !r.Schedule.StartDateTime.IsZero() {
start.Time = *r.Schedule.StartDateTime
start.Valid = true
}
t := sql.NullTime{Time: r.Schedule.Time}
@@ -121,7 +122,7 @@ func dbToRule(dto dbRule) (re.Rule, error) {
OutputChannel: fromNullString(dto.OutputChannel),
OutputTopic: fromNullString(dto.OutputTopic),
Schedule: schedule.Schedule{
StartDateTime: dto.StartDateTime.Time,
StartDateTime: &dto.StartDateTime.Time,
Time: dto.Time.Time,
Recurring: dto.Recurring,
RecurringPeriod: dto.RecurringPeriod,
+4 -4
View File
@@ -56,10 +56,10 @@ func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (Rule,
r.DomainID = session.DomainID
r.Status = EnabledStatus
if r.Schedule.StartDateTime.IsZero() {
r.Schedule.StartDateTime = now
if r.Schedule.StartDateTime == nil || r.Schedule.StartDateTime.IsZero() {
r.Schedule.StartDateTime = &now
}
r.Schedule.Time = r.Schedule.StartDateTime
r.Schedule.Time = *r.Schedule.StartDateTime
rule, err := re.repo.AddRule(ctx, r)
if err != nil {
@@ -103,7 +103,7 @@ func (re *re) UpdateRuleTags(ctx context.Context, session authn.Session, r Rule)
func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r Rule) (Rule, error) {
r.UpdatedAt = time.Now().UTC()
r.UpdatedBy = session.UserID
r.Schedule.Time = r.Schedule.StartDateTime
r.Schedule.Time = *r.Schedule.StartDateTime
rule, err := re.repo.UpdateRuleSchedule(ctx, r)
if err != nil {
return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
+11 -10
View File
@@ -28,15 +28,16 @@ import (
)
var (
namegen = namegenerator.NewGenerator()
userID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
ruleName = namegen.Generate()
ruleID = testsutil.GenerateUUID(&testing.T{})
Tags = []string{"tag1", "tag2"}
inputChannel = "test.channel"
schedule = pkgSch.Schedule{
StartDateTime: time.Now().Add(-time.Hour),
namegen = namegenerator.NewGenerator()
userID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
ruleName = namegen.Generate()
ruleID = testsutil.GenerateUUID(&testing.T{})
Tags = []string{"tag1", "tag2"}
inputChannel = "test.channel"
StartDateTime = time.Now().Add(-time.Hour)
schedule = pkgSch.Schedule{
StartDateTime: &StartDateTime,
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: time.Now().Add(-time.Hour),
@@ -344,7 +345,7 @@ func TestListRules(t *testing.T) {
Recurring: pkgSch.Daily,
Time: now.Add(1 * time.Hour),
RecurringPeriod: 1,
StartDateTime: now.Add(-1 * time.Hour),
StartDateTime: &now,
},
}
rules = append(rules, r)
+7 -7
View File
@@ -41,11 +41,12 @@ var (
validToken = "valid"
invalidToken = "invalid"
now = time.Now().UTC().Truncate(time.Minute)
future = now.Add(1 * time.Hour)
schedule = pkgSch.Schedule{
StartDateTime: now.Add(1 * time.Hour),
StartDateTime: &future,
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
Time: future,
}
reportConfig = reports.ReportConfig{
ID: validID,
@@ -126,7 +127,7 @@ func TestAddReportConfigEndpoint(t *testing.T) {
defer ts.Close()
scheduleInPast := pkgSch.Schedule{
StartDateTime: now.Add(-1 * time.Hour),
StartDateTime: &now,
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: now,
@@ -203,7 +204,6 @@ func TestAddReportConfigEndpoint(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
cfg: reportInPast,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
},
@@ -214,9 +214,9 @@ func TestAddReportConfigEndpoint(t *testing.T) {
authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID},
cfg: reportConfig,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
status: http.StatusForbidden,
err: svcerr.ErrAuthorization,
svcErr: svcerr.ErrCreateEntity,
status: http.StatusUnprocessableEntity,
err: svcerr.ErrCreateEntity,
},
}
+4 -9
View File
@@ -5,7 +5,6 @@ package api
import (
"fmt"
"time"
"github.com/absmach/magistrala/pkg/schedule"
"github.com/absmach/magistrala/reports"
@@ -22,8 +21,6 @@ const (
errInvalidMetric = "invalid metric[%d]: %w"
)
var ErrStartDateTimeInPast = errors.New("start_datetime must be greater than or equal to current time")
var (
errInvalidReportAction = errors.New("invalid report action")
errMetricsNotProvided = errors.New("metrics not provided")
@@ -41,9 +38,8 @@ func (req addReportConfigReq) validate() error {
if req.Name == "" {
return apiutil.ErrMissingName
}
now := time.Now().UTC()
if req.Schedule.StartDateTime.Before(now) {
return errors.Wrap(apiutil.ErrValidation, ErrStartDateTimeInPast)
if err := req.Schedule.Validate(); err != nil {
return errors.Wrap(err, apiutil.ErrValidation)
}
return validateReportConfig(req.ReportConfig, false, false)
}
@@ -91,9 +87,8 @@ func (req updateReportScheduleReq) validate() error {
return apiutil.ErrMissingID
}
now := time.Now().UTC()
if req.Schedule.StartDateTime.Before(now) {
return errors.Wrap(apiutil.ErrValidation, ErrStartDateTimeInPast)
if err := req.Schedule.Validate(); err != nil {
return errors.Wrap(err, apiutil.ErrValidation)
}
return nil
+4 -4
View File
@@ -60,9 +60,9 @@ func reportToDb(r reports.ReportConfig) (dbReport, error) {
}
email = e
}
start := sql.NullTime{Time: r.Schedule.StartDateTime}
if !r.Schedule.StartDateTime.IsZero() {
start := sql.NullTime{}
if r.Schedule.StartDateTime != nil && !r.Schedule.StartDateTime.IsZero() {
start.Time = *r.Schedule.StartDateTime
start.Valid = true
}
t := sql.NullTime{Time: r.Schedule.Time}
@@ -120,7 +120,7 @@ func dbToReport(dto dbReport) (reports.ReportConfig, error) {
Config: &config,
Metrics: metrics,
Schedule: schedule.Schedule{
StartDateTime: dto.StartDateTime.Time,
StartDateTime: &dto.StartDateTime.Time,
Time: dto.Due.Time,
Recurring: dto.Recurring,
RecurringPeriod: dto.RecurringPeriod,
+2 -2
View File
@@ -122,7 +122,7 @@ type Metric struct {
Name string `json:"name,omitempty"` // Mandatory field
Subtopic string `json:"subtopic,omitempty"` // Optional field
Protocol string `json:"protocol,omitempty"` // Optional field
Format string `json:"format,omitiempty"` // Optional field
Format string `json:"format,omitempty"` // Optional field
}
type ReqMetric struct {
@@ -131,7 +131,7 @@ type ReqMetric struct {
Name string `json:"name,omitempty"` // Mandatory field
Subtopic string `json:"subtopic,omitempty"` // Optional field
Protocol string `json:"protocol,omitempty"` // Optional field
Format string `json:"format,omitiempty"` // Optional field
Format string `json:"format,omitempty"` // Optional field
}
func (rm ReqMetric) Validate() error {
+5 -5
View File
@@ -49,17 +49,17 @@ func (r *report) AddReportConfig(ctx context.Context, session authn.Session, cfg
return ReportConfig{}, err
}
now := time.Now()
now := time.Now().UTC()
cfg.ID = id
cfg.CreatedAt = now
cfg.CreatedBy = session.UserID
cfg.DomainID = session.DomainID
cfg.Status = EnabledStatus
if cfg.Schedule.StartDateTime.IsZero() {
cfg.Schedule.StartDateTime = now
if cfg.Schedule.StartDateTime == nil || cfg.Schedule.StartDateTime.IsZero() {
cfg.Schedule.StartDateTime = &now
}
cfg.Schedule.Time = cfg.Schedule.StartDateTime
cfg.Schedule.Time = *cfg.Schedule.StartDateTime
reportConfig, err := r.repo.AddReportConfig(ctx, cfg)
if err != nil {
@@ -92,7 +92,7 @@ func (r *report) UpdateReportConfig(ctx context.Context, session authn.Session,
func (r *report) UpdateReportSchedule(ctx context.Context, session authn.Session, cfg ReportConfig) (ReportConfig, error) {
cfg.UpdatedAt = time.Now().UTC()
cfg.UpdatedBy = session.UserID
cfg.Schedule.Time = cfg.Schedule.StartDateTime
cfg.Schedule.Time = *cfg.Schedule.StartDateTime
c, err := r.repo.UpdateReportSchedule(ctx, cfg)
if err != nil {
return ReportConfig{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
+2 -1
View File
@@ -30,8 +30,9 @@ var (
namegen = namegenerator.NewGenerator()
userID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
now = time.Now().UTC()
schedule = pkgSch.Schedule{
StartDateTime: time.Now().Add(-time.Hour),
StartDateTime: &now,
Recurring: pkgSch.Daily,
RecurringPeriod: 1,
Time: time.Now().Add(-time.Hour),