diff --git a/Makefile b/Makefile index e016c0080..44592f74f 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ MG_DOCKER_IMAGE_NAME_PREFIX ?= ghcr.io/absmach/magistrala BUILD_DIR = build -SERVICES = bootstrap provision re postgres-writer postgres-reader timescale-writer timescale-reader cli +SERVICES = bootstrap provision re postgres-writer postgres-reader timescale-writer timescale-reader cli alarms DOCKERS = $(addprefix docker_,$(SERVICES)) DOCKERS_DEV = $(addprefix docker_dev_,$(SERVICES)) CGO_ENABLED ?= 0 @@ -39,9 +39,9 @@ endif define compile_service CGO_ENABLED=$(CGO_ENABLED) GOOS=$(GOOS) GOARCH=$(GOARCH) GOARM=$(GOARM) \ go build -tags $(MG_MESSAGE_BROKER_TYPE) -tags $(MG_ES_TYPE) -ldflags "-s -w \ - -X 'github.com/absmach/magistrala.BuildTime=$(TIME)' \ - -X 'github.com/absmach/magistrala.Version=$(VERSION)' \ - -X 'github.com/absmach/magistrala.Commit=$(COMMIT)'" \ + -X 'github.com/absmach/supermq.BuildTime=$(TIME)' \ + -X 'github.com/absmach/supermq.Version=$(VERSION)' \ + -X 'github.com/absmach/supermq.Commit=$(COMMIT)'" \ -o ${BUILD_DIR}/$(1) cmd/$(1)/main.go endef diff --git a/alarms/alarms.go b/alarms/alarms.go new file mode 100644 index 000000000..ed4cc6647 --- /dev/null +++ b/alarms/alarms.go @@ -0,0 +1,122 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms + +import ( + "context" + "errors" + "time" + + "github.com/absmach/supermq/pkg/authn" +) + +const SeverityMax uint8 = 100 + +var ErrInvalidSeverity = errors.New("invalid severity. Must be between 0 and 100") + +type Metadata map[string]interface{} + +// Alarm represents an alarm instance. +type Alarm struct { + ID string `json:"id"` + RuleID string `json:"rule_id"` + DomainID string `json:"domain_id"` + ChannelID string `json:"channel_id"` + ClientID string `json:"client_id"` + Subtopic string `json:"subtopic"` + Status Status `json:"status"` + Measurement string `json:"measurement"` + Value string `json:"value"` + Unit string `json:"unit"` + Threshold string `json:"threshold"` + Cause string `json:"cause"` + Severity uint8 `json:"severity"` + AssigneeID string `json:"assignee_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + UpdatedBy string `json:"updated_by"` + AssignedAt time.Time `json:"assigned_at,omitempty"` + AssignedBy string `json:"assigned_by,omitempty"` + AcknowledgedAt time.Time `json:"acknowledged_at,omitempty"` + AcknowledgedBy string `json:"acknowledged_by,omitempty"` + ResolvedAt time.Time `json:"resolved_at,omitempty"` + ResolvedBy string `json:"resolved_by,omitempty"` + Metadata Metadata `json:"metadata,omitempty"` +} + +type AlarmsPage struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Total uint64 `json:"total"` + Alarms []Alarm `json:"alarms"` +} + +type PageMetadata struct { + Offset uint64 `json:"offset" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + DomainID string `json:"domain_id" db:"domain_id"` + ChannelID string `json:"channel_id" db:"channel_id"` + ClientID string `json:"client_id" db:"client_id"` + Subtopic string `json:"subtopic" db:"subtopic"` + RuleID string `json:"rule_id" db:"rule_id"` + Status Status `json:"status" db:"status"` + AssigneeID string `json:"assignee_id" db:"assignee_id"` + Severity uint8 `json:"severity" db:"severity"` + UpdatedBy string `json:"updated_by" db:"updated_by"` + AssignedBy string `json:"assigned_by" db:"assigned_by"` + AcknowledgedBy string `json:"acknowledged_by" db:"acknowledged_by"` + ResolvedBy string `json:"resolved_by" db:"resolved_by"` +} + +func (a Alarm) Validate() error { + if a.RuleID == "" { + return errors.New("rule_id is required") + } + if a.DomainID == "" { + return errors.New("domain_id is required") + } + if a.ChannelID == "" { + return errors.New("channel_id is required") + } + if a.ClientID == "" { + return errors.New("client_id is required") + } + if a.Subtopic == "" { + return errors.New("subtopic is required") + } + if a.Measurement == "" { + return errors.New("measurement is required") + } + if a.Value == "" { + return errors.New("value is required") + } + if a.Unit == "" { + return errors.New("unit is required") + } + if a.Cause == "" { + return errors.New("cause is required") + } + if a.Severity > SeverityMax { + return ErrInvalidSeverity + } + + return nil +} + +// Service specifies an API that must be fulfilled by the domain service. +type Service interface { + CreateAlarm(ctx context.Context, alarm Alarm) error + UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error) + ViewAlarm(ctx context.Context, session authn.Session, id string) (Alarm, error) + ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error) + DeleteAlarm(ctx context.Context, session authn.Session, id string) error +} + +type Repository interface { + CreateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) + UpdateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) + ViewAlarm(ctx context.Context, alarmID, domainID string) (Alarm, error) + ListAlarms(ctx context.Context, pm PageMetadata) (AlarmsPage, error) + DeleteAlarm(ctx context.Context, id string) error +} diff --git a/alarms/alarms_test.go b/alarms/alarms_test.go new file mode 100644 index 000000000..2b1d5ac01 --- /dev/null +++ b/alarms/alarms_test.go @@ -0,0 +1,203 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms_test + +import ( + "fmt" + "testing" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/magistrala/internal/testsutil" + "github.com/absmach/magistrala/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateAlarms(t *testing.T) { + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: nil, + }, + { + desc: "missing rule_id", + alarm: alarms.Alarm{ + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("rule_id is required"), + }, + { + desc: "missing domain_id", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("domain_id is required"), + }, + { + desc: "missing channel_id", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("channel_id is required"), + }, + { + desc: "missing client_id", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("client_id is required"), + }, + { + desc: "missing subtopic", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("subtopic is required"), + }, + { + desc: "missing measurement", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("measurement is required"), + }, + { + desc: "missing value", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("value is required"), + }, + { + desc: "missing unit", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Cause: "cause", + Severity: 100, + }, + err: errors.New("unit is required"), + }, + { + desc: "missing cause", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Severity: 100, + }, + err: errors.New("cause is required"), + }, + { + desc: "higher severity", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: alarms.SeverityMax + 1, + }, + err: alarms.ErrInvalidSeverity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.alarm.Validate() + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + }) + } +} diff --git a/alarms/api/doc.go b/alarms/api/doc.go new file mode 100644 index 000000000..2424852cc --- /dev/null +++ b/alarms/api/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package api contains API-related concerns: endpoint definitions, middlewares +// and all resource representations. +package api diff --git a/alarms/api/endpoint.go b/alarms/api/endpoint.go new file mode 100644 index 000000000..f35e7c6d2 --- /dev/null +++ b/alarms/api/endpoint.go @@ -0,0 +1,105 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/go-kit/kit/endpoint" +) + +func updateAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(alarmReq) + if err := req.validate(); err != nil { + return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return alarmRes{}, svcerr.ErrAuthorization + } + + alarm, err := svc.UpdateAlarm(ctx, session, req.Alarm) + if err != nil { + return alarmRes{}, err + } + + return alarmRes{ + Alarm: alarm, + }, nil + } +} + +func viewAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(alarmReq) + if err := req.validate(); err != nil { + return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return alarmRes{}, svcerr.ErrAuthorization + } + + alarm, err := svc.ViewAlarm(ctx, session, req.ID) + if err != nil { + return alarmRes{}, err + } + + return alarmRes{ + Alarm: alarm, + }, nil + } +} + +func listAlarmsEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(listAlarmsReq) + if err := req.validate(); err != nil { + return alarmsPageRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return alarmsPageRes{}, svcerr.ErrAuthorization + } + + alarms, err := svc.ListAlarms(ctx, session, req.PageMetadata) + if err != nil { + return alarmsPageRes{}, err + } + + return alarmsPageRes{ + AlarmsPage: alarms, + }, nil + } +} + +func deleteAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(alarmReq) + if err := req.validate(); err != nil { + return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return alarmRes{}, svcerr.ErrAuthorization + } + + if err := svc.DeleteAlarm(ctx, session, req.ID); err != nil { + return alarmRes{}, err + } + + return alarmRes{deleted: true}, nil + } +} diff --git a/alarms/api/requests.go b/alarms/api/requests.go new file mode 100644 index 000000000..54f677efb --- /dev/null +++ b/alarms/api/requests.go @@ -0,0 +1,36 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "errors" + + "github.com/absmach/magistrala/alarms" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" +) + +type alarmReq struct { + alarms.Alarm `json:",inline"` +} + +func (req alarmReq) validate() error { + if req.Alarm.ID == "" { + return errors.New("missing alarm id") + } + + return nil +} + +type listAlarmsReq struct { + alarms.PageMetadata +} + +func (req listAlarmsReq) validate() error { + if req.Limit > api.MaxLimitSize || req.Limit < 1 { + return apiutil.ErrLimitSize + } + + return nil +} diff --git a/alarms/api/responses.go b/alarms/api/responses.go new file mode 100644 index 000000000..4a499735d --- /dev/null +++ b/alarms/api/responses.go @@ -0,0 +1,70 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "net/http" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq" +) + +var ( + _ supermq.Response = (*alarmRes)(nil) + _ supermq.Response = (*alarmsPageRes)(nil) +) + +type alarmRes struct { + alarms.Alarm `json:",inline"` + created bool + deleted bool +} + +func (res alarmRes) Headers() map[string]string { + switch { + case res.created: + return map[string]string{ + "Location": fmt.Sprintf("/%s/alarms/%s", res.DomainID, res.ID), + } + default: + return map[string]string{} + } +} + +func (res alarmRes) Code() int { + switch { + case res.created: + return http.StatusCreated + case res.deleted: + return http.StatusNoContent + default: + return http.StatusOK + } +} + +func (res alarmRes) Empty() bool { + switch { + case res.deleted: + return true + default: + return false + } +} + +type alarmsPageRes struct { + alarms.AlarmsPage `json:",inline"` +} + +func (res alarmsPageRes) Headers() map[string]string { + return map[string]string{} +} + +func (res alarmsPageRes) Code() int { + return http.StatusOK +} + +func (res alarmsPageRes) Empty() bool { + return false +} diff --git a/alarms/api/transport.go b/alarms/api/transport.go new file mode 100644 index 000000000..54e2d9118 --- /dev/null +++ b/alarms/api/transport.go @@ -0,0 +1,176 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "log/slog" + "math" + "net/http" + "strings" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" +) + +func MakeHandler(svc alarms.Service, logger *slog.Logger, idp supermq.IDProvider, instanceID string, authn smqauthn.Authentication) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + + mux := chi.NewRouter() + + mux.Route("/{domainID}/alarms", func(r chi.Router) { + r.Group(func(r chi.Router) { + r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(api.RequestIDMiddleware(idp)) + + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listAlarmsEndpoint(svc), + decodeListAlarmsReq, + api.EncodeResponse, + opts..., + ), "list_alarms").ServeHTTP) + r.Route("/{alarmID}", func(r chi.Router) { + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + viewAlarmEndpoint(svc), + decodeAlarmReq, + api.EncodeResponse, + opts..., + ), "get_alarm").ServeHTTP) + r.Put("/", otelhttp.NewHandler(kithttp.NewServer( + updateAlarmEndpoint(svc), + decodeUpdateAlarmReq, + api.EncodeResponse, + opts..., + ), "update_alarm").ServeHTTP) + r.Delete("/", otelhttp.NewHandler(kithttp.NewServer( + deleteAlarmEndpoint(svc), + decodeAlarmReq, + api.EncodeResponse, + opts..., + ), "delete_alarm").ServeHTTP) + }) + }) + }) + + mux.Get("/health", supermq.Health("alarms", instanceID)) + mux.Handle("/metrics", promhttp.Handler()) + + return mux +} + +func decodeListAlarmsReq(_ context.Context, r *http.Request) (interface{}, error) { + offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + domainID, err := apiutil.ReadStringQuery(r, "domain_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + channelID, err := apiutil.ReadStringQuery(r, "channel_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + clientID, err := apiutil.ReadStringQuery(r, "client_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + subtopic, err := apiutil.ReadStringQuery(r, "subtopic", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + ruleID, err := apiutil.ReadStringQuery(r, "rule_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + s, err := apiutil.ReadStringQuery(r, api.StatusKey, alarms.All) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + status, err := alarms.ToStatus(s) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + assigneeID, err := apiutil.ReadStringQuery(r, "assignee_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + serverity, err := apiutil.ReadNumQuery(r, "severity", uint64(math.MaxUint8)) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + updatedBy, err := apiutil.ReadStringQuery(r, "updated_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + assignedBy, err := apiutil.ReadStringQuery(r, "assigned_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + acknowledgedBy, err := apiutil.ReadStringQuery(r, "acknowledged_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + resolvedBy, err := apiutil.ReadStringQuery(r, "resolved_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + return listAlarmsReq{ + PageMetadata: alarms.PageMetadata{ + Offset: offset, + Limit: limit, + DomainID: domainID, + ChannelID: channelID, + ClientID: clientID, + Subtopic: subtopic, + RuleID: ruleID, + Status: status, + AssigneeID: assigneeID, + ResolvedBy: resolvedBy, + Severity: uint8(serverity), + UpdatedBy: updatedBy, + AcknowledgedBy: acknowledgedBy, + AssignedBy: assignedBy, + }, + }, nil +} + +func decodeAlarmReq(_ context.Context, r *http.Request) (interface{}, error) { + return alarmReq{ + Alarm: alarms.Alarm{ + ID: chi.URLParam(r, "alarmID"), + }, + }, nil +} + +func decodeUpdateAlarmReq(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return alarmReq{}, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) + } + + req := alarmReq{} + if err := json.NewDecoder(r.Body).Decode(&req.Alarm); err != nil { + return alarmReq{}, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err)) + } + + req.Alarm.ID = chi.URLParam(r, "alarmID") + + return req, nil +} diff --git a/alarms/consumer/brokers/brokers_nats.go b/alarms/consumer/brokers/brokers_nats.go new file mode 100644 index 000000000..0f24a26dc --- /dev/null +++ b/alarms/consumer/brokers/brokers_nats.go @@ -0,0 +1,39 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !rabbitmq +// +build !rabbitmq + +package brokers + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/pkg/messaging" + broker "github.com/absmach/supermq/pkg/messaging/nats" + "github.com/nats-io/nats.go/jetstream" +) + +const AllTopic = "alarms.>" + +func NewPubSub(ctx context.Context, url string, logger *slog.Logger) (messaging.PubSub, error) { + cfg := jetstream.StreamConfig{ + Name: "alarms", + Description: "SuperMQ stream alarms", + Subjects: []string{"alarms.>"}, + Retention: jetstream.LimitsPolicy, + MaxMsgsPerSubject: 1e6, + MaxAge: time.Hour * 24, + MaxMsgSize: 1024 * 1024, + Discard: jetstream.DiscardOld, + Storage: jetstream.FileStorage, + } + pb, err := broker.NewPubSub(ctx, url, logger, broker.JSStreamConfig(cfg)) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/alarms/consumer/brokers/brokers_rabbitmq.go b/alarms/consumer/brokers/brokers_rabbitmq.go new file mode 100644 index 000000000..cb0960223 --- /dev/null +++ b/alarms/consumer/brokers/brokers_rabbitmq.go @@ -0,0 +1,26 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build rabbitmq +// +build rabbitmq + +package brokers + +import ( + "context" + "log/slog" + + "github.com/absmach/supermq/pkg/messaging" + broker "github.com/absmach/supermq/pkg/messaging/rabbitmq" +) + +const AllTopic = "alarms.#" + +func NewPubSub(ctx context.Context, url string, logger *slog.Logger) (messaging.PubSub, error) { + pb, err := broker.NewPubSub(ctx, url, logger, broker.Prefix("alarms")) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/alarms/consumer/consumer.go b/alarms/consumer/consumer.go new file mode 100644 index 000000000..8b51979f7 --- /dev/null +++ b/alarms/consumer/consumer.go @@ -0,0 +1,57 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "bytes" + "context" + "encoding/gob" + "log/slog" + "time" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/messaging" +) + +type handler struct { + svc alarms.Service + logger *slog.Logger +} + +func Newhandler(svc alarms.Service, logger *slog.Logger) messaging.MessageHandler { + return &handler{svc: svc, logger: logger} +} + +func (h handler) Handle(msg *messaging.Message) (err error) { + if msg == nil { + return errors.New("message is empty") + } + if msg.GetPayload() == nil { + return errors.New("message payload is empty") + } + + var alarm alarms.Alarm + if err := gob.NewDecoder(bytes.NewReader(msg.GetPayload())).Decode(&alarm); err != nil { + return err + } + alarm.DomainID = msg.GetDomain() + alarm.ChannelID = msg.GetChannel() + alarm.ClientID = msg.GetPublisher() + alarm.Subtopic = msg.GetSubtopic() + + if alarm.CreatedAt.IsZero() { + alarm.CreatedAt = time.Unix(0, int64(msg.GetCreated())) + } + + if err := alarm.Validate(); err != nil { + return err + } + + return h.svc.CreateAlarm(context.Background(), alarm) +} + +func (h handler) Cancel() error { + return nil +} diff --git a/alarms/doc.go b/alarms/doc.go new file mode 100644 index 000000000..9f7866f33 --- /dev/null +++ b/alarms/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package alarms contains domain concept definitions needed to support +// Alarms service feature, i.e. create, read, update, and delete alarms. +package alarms diff --git a/alarms/middleware/authorization.go b/alarms/middleware/authorization.go new file mode 100644 index 000000000..bad879396 --- /dev/null +++ b/alarms/middleware/authorization.go @@ -0,0 +1,124 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/authn" + smqauthz "github.com/absmach/supermq/pkg/authz" + "github.com/absmach/supermq/pkg/policies" +) + +type authorizationMiddleware struct { + svc alarms.Service + authz smqauthz.Authorization +} + +var _ alarms.Service = (*authorizationMiddleware)(nil) + +func NewAuthorizationMiddleware(svc alarms.Service, authz smqauthz.Authorization) alarms.Service { + return &authorizationMiddleware{ + svc: svc, + authz: authz, + } +} + +func (am *authorizationMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (err error) { + return am.svc.CreateAlarm(ctx, alarm) +} + +func (am *authorizationMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (dba alarms.Alarm, err error) { + // if assignee is present check if assignee is member of domain + + req := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: policies.AdminPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + + if err := am.authz.Authorize(ctx, req); err != nil { + return alarms.Alarm{}, err + } + + if alarm.AssigneeID != "" { + domainUserId := auth.EncodeDomainUserID(session.DomainID, alarm.AssigneeID) + if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: domainUserId, + Permission: policies.MembershipPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + }); err != nil { + return alarms.Alarm{}, err + } + } + + return am.svc.UpdateAlarm(ctx, session, alarm) +} + +func (am *authorizationMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + req := smqauthz.PolicyReq{ + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: policies.AdminPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + + if err := am.authz.Authorize(ctx, req); err != nil { + return err + } + + return am.svc.DeleteAlarm(ctx, session, id) +} + +func (am *authorizationMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + if pm.DomainID == "" { + pm.DomainID = session.DomainID + } + + req := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: policies.MembershipPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + + if err := am.authz.Authorize(ctx, req); err != nil { + return alarms.AlarmsPage{}, err + } + + return am.svc.ListAlarms(ctx, session, pm) +} + +func (am *authorizationMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + req := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: policies.MembershipPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + + if err := am.authz.Authorize(ctx, req); err != nil { + return alarms.Alarm{}, err + } + + return am.svc.ViewAlarm(ctx, session, id) +} diff --git a/alarms/middleware/doc.go b/alarms/middleware/doc.go new file mode 100644 index 000000000..ce4a296d2 --- /dev/null +++ b/alarms/middleware/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package middleware provides middleware for the alarms service. +// This is logging, metrics, and tracing middleware. +package middleware diff --git a/alarms/middleware/logging.go b/alarms/middleware/logging.go new file mode 100644 index 000000000..c8e265358 --- /dev/null +++ b/alarms/middleware/logging.go @@ -0,0 +1,151 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/authn" + "github.com/go-chi/chi/v5/middleware" +) + +type loggingMiddleware struct { + logger *slog.Logger + service alarms.Service +} + +var _ alarms.Service = (*loggingMiddleware)(nil) + +func NewLoggingMiddleware(logger *slog.Logger, service alarms.Service) alarms.Service { + return &loggingMiddleware{ + logger: logger, + service: service, + } +} + +func (lm *loggingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.Group("alarm", + slog.String("rule_id", alarm.RuleID), + slog.String("domain_id", alarm.DomainID), + slog.String("channel_id", alarm.ChannelID), + slog.String("client_id", alarm.ClientID), + slog.String("subtopic", alarm.Subtopic), + slog.String("measurement", alarm.Measurement), + slog.String("value", alarm.Value), + slog.String("unit", alarm.Unit), + slog.String("threshold", alarm.Threshold), + slog.String("cause", alarm.Cause), + slog.Uint64("severity", uint64(alarm.Severity)), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Create alarm failed", args...) + return + } + lm.logger.Info("Create alarm completed successfully", args...) + }(time.Now()) + + return lm.service.CreateAlarm(ctx, alarm) +} + +func (lm *loggingMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (dba alarms.Alarm, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.Group("alarm", + slog.String("id", dba.ID), + slog.String("rule_id", dba.RuleID), + slog.String("domain_id", dba.DomainID), + slog.String("channel_id", dba.ChannelID), + slog.String("client_id", dba.ClientID), + slog.String("subtopic", dba.Subtopic), + slog.String("measurement", dba.Measurement), + slog.String("value", dba.Value), + slog.String("unit", dba.Unit), + slog.String("threshold", dba.Threshold), + slog.String("cause", dba.Cause), + slog.Uint64("severity", uint64(dba.Severity)), + slog.String("status", dba.Status.String()), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update alarm failed", args...) + return + } + lm.logger.Info("Update alarm completed successfully", args...) + }(time.Now()) + + return lm.service.UpdateAlarm(ctx, session, alarm) +} + +func (lm *loggingMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (dba alarms.Alarm, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.String("id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("View alarm failed", args...) + return + } + lm.logger.Info("View alarm completed successfully", args...) + }(time.Now()) + + return lm.service.ViewAlarm(ctx, session, id) +} + +func (lm *loggingMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (dbp alarms.AlarmsPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.Int("offset", int(pm.Offset)), + slog.Int("limit", int(pm.Limit)), + slog.String("rule_id", pm.RuleID), + slog.String("domain_id", pm.DomainID), + slog.String("channel_id", pm.ChannelID), + slog.String("client_id", pm.ClientID), + slog.String("subtopic", pm.Subtopic), + slog.Uint64("severity", uint64(pm.Severity)), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("List alarms failed", args...) + return + } + lm.logger.Info("List alarms completed successfully", args...) + }(time.Now()) + + return lm.service.ListAlarms(ctx, session, pm) +} + +func (lm *loggingMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.String("id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Delete alarm failed", args...) + return + } + lm.logger.Info("Delete alarm completed successfully", args...) + }(time.Now()) + + return lm.service.DeleteAlarm(ctx, session, id) +} diff --git a/alarms/middleware/metrics.go b/alarms/middleware/metrics.go new file mode 100644 index 000000000..3d99baafc --- /dev/null +++ b/alarms/middleware/metrics.go @@ -0,0 +1,74 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "time" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/authn" + "github.com/go-kit/kit/metrics" +) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + service alarms.Service +} + +var _ alarms.Service = (*metricsMiddleware)(nil) + +func NewMetricsMiddleware(counter metrics.Counter, latency metrics.Histogram, service alarms.Service) alarms.Service { + return &metricsMiddleware{ + counter: counter, + latency: latency, + service: service, + } +} + +func (mm *metricsMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + defer func(begin time.Time) { + mm.counter.With("method", "create_alarm").Add(1) + mm.latency.With("method", "create_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.CreateAlarm(ctx, alarm) +} + +func (mm *metricsMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_alarm").Add(1) + mm.latency.With("method", "update_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.UpdateAlarm(ctx, session, alarm) +} + +func (mm *metricsMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + defer func(begin time.Time) { + mm.counter.With("method", "get_alarm").Add(1) + mm.latency.With("method", "get_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ViewAlarm(ctx, session, id) +} + +func (mm *metricsMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_alarms").Add(1) + mm.latency.With("method", "list_alarms").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ListAlarms(ctx, session, pm) +} + +func (mm *metricsMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + defer func(begin time.Time) { + mm.counter.With("method", "delete_alarm").Add(1) + mm.latency.With("method", "delete_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.DeleteAlarm(ctx, session, id) +} diff --git a/alarms/middleware/tracing.go b/alarms/middleware/tracing.go new file mode 100644 index 000000000..a969af409 --- /dev/null +++ b/alarms/middleware/tracing.go @@ -0,0 +1,84 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/authn" + smqTracing "github.com/absmach/supermq/pkg/tracing" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +type tracingMiddleware struct { + tracer trace.Tracer + svc alarms.Service +} + +var _ alarms.Service = (*tracingMiddleware)(nil) + +func NewTracingMiddleware(tracer trace.Tracer, svc alarms.Service) alarms.Service { + return &tracingMiddleware{ + tracer: tracer, + svc: svc, + } +} + +func (tm *tracingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "create_alarm", trace.WithAttributes( + attribute.String("rule_id", alarm.RuleID), + attribute.String("measurement", alarm.Measurement), + attribute.String("value", alarm.Value), + attribute.String("unit", alarm.Unit), + attribute.String("cause", alarm.Cause), + attribute.String("status", alarm.Status.String()), + )) + defer span.End() + + return tm.svc.CreateAlarm(ctx, alarm) +} + +func (tm *tracingMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_alarm", trace.WithAttributes( + attribute.String("rule_id", alarm.RuleID), + attribute.String("measurement", alarm.Measurement), + attribute.String("value", alarm.Value), + attribute.String("unit", alarm.Unit), + attribute.String("cause", alarm.Cause), + attribute.String("status", alarm.Status.String()), + )) + defer span.End() + + return tm.svc.UpdateAlarm(ctx, session, alarm) +} + +func (tm *tracingMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "get_alarm", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.ViewAlarm(ctx, session, id) +} + +func (tm *tracingMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "list_alarms", trace.WithAttributes( + attribute.Int("offset", int(pm.Offset)), + attribute.Int("limit", int(pm.Limit)), + )) + defer span.End() + + return tm.svc.ListAlarms(ctx, session, pm) +} + +func (tm *tracingMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "delete_alarm", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.DeleteAlarm(ctx, session, id) +} diff --git a/alarms/mocks/repository.go b/alarms/mocks/repository.go new file mode 100644 index 000000000..12017d769 --- /dev/null +++ b/alarms/mocks/repository.go @@ -0,0 +1,309 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + mock "github.com/stretchr/testify/mock" +) + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// CreateAlarm provides a mock function for the type Repository +func (_mock *Repository) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + ret := _mock.Called(ctx, alarm) + + if len(ret) == 0 { + panic("no return value specified for CreateAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) (alarms.Alarm, error)); ok { + return returnFunc(ctx, alarm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) alarms.Alarm); ok { + r0 = returnFunc(ctx, alarm) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.Alarm) error); ok { + r1 = returnFunc(ctx, alarm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_CreateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlarm' +type Repository_CreateAlarm_Call struct { + *mock.Call +} + +// CreateAlarm is a helper method to define mock.On call +// - ctx +// - alarm +func (_e *Repository_Expecter) CreateAlarm(ctx interface{}, alarm interface{}) *Repository_CreateAlarm_Call { + return &Repository_CreateAlarm_Call{Call: _e.mock.On("CreateAlarm", ctx, alarm)} +} + +func (_c *Repository_CreateAlarm_Call) Run(run func(ctx context.Context, alarm alarms.Alarm)) *Repository_CreateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(alarms.Alarm)) + }) + return _c +} + +func (_c *Repository_CreateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Repository_CreateAlarm_Call { + _c.Call.Return(alarm1, err) + return _c +} + +func (_c *Repository_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error)) *Repository_CreateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// DeleteAlarm provides a mock function for the type Repository +func (_mock *Repository) DeleteAlarm(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteAlarm") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_DeleteAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteAlarm' +type Repository_DeleteAlarm_Call struct { + *mock.Call +} + +// DeleteAlarm is a helper method to define mock.On call +// - ctx +// - id +func (_e *Repository_Expecter) DeleteAlarm(ctx interface{}, id interface{}) *Repository_DeleteAlarm_Call { + return &Repository_DeleteAlarm_Call{Call: _e.mock.On("DeleteAlarm", ctx, id)} +} + +func (_c *Repository_DeleteAlarm_Call) Run(run func(ctx context.Context, id string)) *Repository_DeleteAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_DeleteAlarm_Call) Return(err error) *Repository_DeleteAlarm_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_DeleteAlarm_Call) RunAndReturn(run func(ctx context.Context, id string) error) *Repository_DeleteAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ListAlarms provides a mock function for the type Repository +func (_mock *Repository) ListAlarms(ctx context.Context, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ret := _mock.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAlarms") + } + + var r0 alarms.AlarmsPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.PageMetadata) (alarms.AlarmsPage, error)); ok { + return returnFunc(ctx, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.PageMetadata) alarms.AlarmsPage); ok { + r0 = returnFunc(ctx, pm) + } else { + r0 = ret.Get(0).(alarms.AlarmsPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.PageMetadata) error); ok { + r1 = returnFunc(ctx, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAlarms' +type Repository_ListAlarms_Call struct { + *mock.Call +} + +// ListAlarms is a helper method to define mock.On call +// - ctx +// - pm +func (_e *Repository_Expecter) ListAlarms(ctx interface{}, pm interface{}) *Repository_ListAlarms_Call { + return &Repository_ListAlarms_Call{Call: _e.mock.On("ListAlarms", ctx, pm)} +} + +func (_c *Repository_ListAlarms_Call) Run(run func(ctx context.Context, pm alarms.PageMetadata)) *Repository_ListAlarms_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(alarms.PageMetadata)) + }) + return _c +} + +func (_c *Repository_ListAlarms_Call) Return(alarmsPage alarms.AlarmsPage, err error) *Repository_ListAlarms_Call { + _c.Call.Return(alarmsPage, err) + return _c +} + +func (_c *Repository_ListAlarms_Call) RunAndReturn(run func(ctx context.Context, pm alarms.PageMetadata) (alarms.AlarmsPage, error)) *Repository_ListAlarms_Call { + _c.Call.Return(run) + return _c +} + +// UpdateAlarm provides a mock function for the type Repository +func (_mock *Repository) UpdateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + ret := _mock.Called(ctx, alarm) + + if len(ret) == 0 { + panic("no return value specified for UpdateAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) (alarms.Alarm, error)); ok { + return returnFunc(ctx, alarm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) alarms.Alarm); ok { + r0 = returnFunc(ctx, alarm) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.Alarm) error); ok { + r1 = returnFunc(ctx, alarm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAlarm' +type Repository_UpdateAlarm_Call struct { + *mock.Call +} + +// UpdateAlarm is a helper method to define mock.On call +// - ctx +// - alarm +func (_e *Repository_Expecter) UpdateAlarm(ctx interface{}, alarm interface{}) *Repository_UpdateAlarm_Call { + return &Repository_UpdateAlarm_Call{Call: _e.mock.On("UpdateAlarm", ctx, alarm)} +} + +func (_c *Repository_UpdateAlarm_Call) Run(run func(ctx context.Context, alarm alarms.Alarm)) *Repository_UpdateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(alarms.Alarm)) + }) + return _c +} + +func (_c *Repository_UpdateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Repository_UpdateAlarm_Call { + _c.Call.Return(alarm1, err) + return _c +} + +func (_c *Repository_UpdateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error)) *Repository_UpdateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ViewAlarm provides a mock function for the type Repository +func (_mock *Repository) ViewAlarm(ctx context.Context, alarmID string, domainID string) (alarms.Alarm, error) { + ret := _mock.Called(ctx, alarmID, domainID) + + if len(ret) == 0 { + panic("no return value specified for ViewAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (alarms.Alarm, error)); ok { + return returnFunc(ctx, alarmID, domainID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) alarms.Alarm); ok { + r0 = returnFunc(ctx, alarmID, domainID) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, alarmID, domainID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ViewAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewAlarm' +type Repository_ViewAlarm_Call struct { + *mock.Call +} + +// ViewAlarm is a helper method to define mock.On call +// - ctx +// - alarmID +// - domainID +func (_e *Repository_Expecter) ViewAlarm(ctx interface{}, alarmID interface{}, domainID interface{}) *Repository_ViewAlarm_Call { + return &Repository_ViewAlarm_Call{Call: _e.mock.On("ViewAlarm", ctx, alarmID, domainID)} +} + +func (_c *Repository_ViewAlarm_Call) Run(run func(ctx context.Context, alarmID string, domainID string)) *Repository_ViewAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *Repository_ViewAlarm_Call) Return(alarm alarms.Alarm, err error) *Repository_ViewAlarm_Call { + _c.Call.Return(alarm, err) + return _c +} + +func (_c *Repository_ViewAlarm_Call) RunAndReturn(run func(ctx context.Context, alarmID string, domainID string) (alarms.Alarm, error)) *Repository_ViewAlarm_Call { + _c.Call.Return(run) + return _c +} diff --git a/alarms/mocks/service.go b/alarms/mocks/service.go new file mode 100644 index 000000000..6e984cc47 --- /dev/null +++ b/alarms/mocks/service.go @@ -0,0 +1,304 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/authn" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// CreateAlarm provides a mock function for the type Service +func (_mock *Service) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + ret := _mock.Called(ctx, alarm) + + if len(ret) == 0 { + panic("no return value specified for CreateAlarm") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) error); ok { + r0 = returnFunc(ctx, alarm) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_CreateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlarm' +type Service_CreateAlarm_Call struct { + *mock.Call +} + +// CreateAlarm is a helper method to define mock.On call +// - ctx +// - alarm +func (_e *Service_Expecter) CreateAlarm(ctx interface{}, alarm interface{}) *Service_CreateAlarm_Call { + return &Service_CreateAlarm_Call{Call: _e.mock.On("CreateAlarm", ctx, alarm)} +} + +func (_c *Service_CreateAlarm_Call) Run(run func(ctx context.Context, alarm alarms.Alarm)) *Service_CreateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(alarms.Alarm)) + }) + return _c +} + +func (_c *Service_CreateAlarm_Call) Return(err error) *Service_CreateAlarm_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) error) *Service_CreateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// DeleteAlarm provides a mock function for the type Service +func (_mock *Service) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteAlarm") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_DeleteAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteAlarm' +type Service_DeleteAlarm_Call struct { + *mock.Call +} + +// DeleteAlarm is a helper method to define mock.On call +// - ctx +// - session +// - id +func (_e *Service_Expecter) DeleteAlarm(ctx interface{}, session interface{}, id interface{}) *Service_DeleteAlarm_Call { + return &Service_DeleteAlarm_Call{Call: _e.mock.On("DeleteAlarm", ctx, session, id)} +} + +func (_c *Service_DeleteAlarm_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_DeleteAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authn.Session), args[2].(string)) + }) + return _c +} + +func (_c *Service_DeleteAlarm_Call) Return(err error) *Service_DeleteAlarm_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_DeleteAlarm_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) error) *Service_DeleteAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ListAlarms provides a mock function for the type Service +func (_mock *Service) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ret := _mock.Called(ctx, session, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAlarms") + } + + var r0 alarms.AlarmsPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.PageMetadata) (alarms.AlarmsPage, error)); ok { + return returnFunc(ctx, session, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.PageMetadata) alarms.AlarmsPage); ok { + r0 = returnFunc(ctx, session, pm) + } else { + r0 = ret.Get(0).(alarms.AlarmsPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, alarms.PageMetadata) error); ok { + r1 = returnFunc(ctx, session, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAlarms' +type Service_ListAlarms_Call struct { + *mock.Call +} + +// ListAlarms is a helper method to define mock.On call +// - ctx +// - session +// - pm +func (_e *Service_Expecter) ListAlarms(ctx interface{}, session interface{}, pm interface{}) *Service_ListAlarms_Call { + return &Service_ListAlarms_Call{Call: _e.mock.On("ListAlarms", ctx, session, pm)} +} + +func (_c *Service_ListAlarms_Call) Run(run func(ctx context.Context, session authn.Session, pm alarms.PageMetadata)) *Service_ListAlarms_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authn.Session), args[2].(alarms.PageMetadata)) + }) + return _c +} + +func (_c *Service_ListAlarms_Call) Return(alarmsPage alarms.AlarmsPage, err error) *Service_ListAlarms_Call { + _c.Call.Return(alarmsPage, err) + return _c +} + +func (_c *Service_ListAlarms_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error)) *Service_ListAlarms_Call { + _c.Call.Return(run) + return _c +} + +// UpdateAlarm provides a mock function for the type Service +func (_mock *Service) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + ret := _mock.Called(ctx, session, alarm) + + if len(ret) == 0 { + panic("no return value specified for UpdateAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.Alarm) (alarms.Alarm, error)); ok { + return returnFunc(ctx, session, alarm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.Alarm) alarms.Alarm); ok { + r0 = returnFunc(ctx, session, alarm) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, alarms.Alarm) error); ok { + r1 = returnFunc(ctx, session, alarm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAlarm' +type Service_UpdateAlarm_Call struct { + *mock.Call +} + +// UpdateAlarm is a helper method to define mock.On call +// - ctx +// - session +// - alarm +func (_e *Service_Expecter) UpdateAlarm(ctx interface{}, session interface{}, alarm interface{}) *Service_UpdateAlarm_Call { + return &Service_UpdateAlarm_Call{Call: _e.mock.On("UpdateAlarm", ctx, session, alarm)} +} + +func (_c *Service_UpdateAlarm_Call) Run(run func(ctx context.Context, session authn.Session, alarm alarms.Alarm)) *Service_UpdateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authn.Session), args[2].(alarms.Alarm)) + }) + return _c +} + +func (_c *Service_UpdateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Service_UpdateAlarm_Call { + _c.Call.Return(alarm1, err) + return _c +} + +func (_c *Service_UpdateAlarm_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error)) *Service_UpdateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ViewAlarm provides a mock function for the type Service +func (_mock *Service) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for ViewAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (alarms.Alarm, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) alarms.Alarm); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ViewAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewAlarm' +type Service_ViewAlarm_Call struct { + *mock.Call +} + +// ViewAlarm is a helper method to define mock.On call +// - ctx +// - session +// - id +func (_e *Service_Expecter) ViewAlarm(ctx interface{}, session interface{}, id interface{}) *Service_ViewAlarm_Call { + return &Service_ViewAlarm_Call{Call: _e.mock.On("ViewAlarm", ctx, session, id)} +} + +func (_c *Service_ViewAlarm_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_ViewAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authn.Session), args[2].(string)) + }) + return _c +} + +func (_c *Service_ViewAlarm_Call) Return(alarm alarms.Alarm, err error) *Service_ViewAlarm_Call { + _c.Call.Return(alarm, err) + return _c +} + +func (_c *Service_ViewAlarm_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error)) *Service_ViewAlarm_Call { + _c.Call.Return(run) + return _c +} diff --git a/alarms/postgres/alarms.go b/alarms/postgres/alarms.go new file mode 100644 index 000000000..ddd85e127 --- /dev/null +++ b/alarms/postgres/alarms.go @@ -0,0 +1,427 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "math" + "strings" + "time" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" +) + +type repository struct { + db *sqlx.DB +} + +var _ alarms.Repository = (*repository)(nil) + +func NewAlarmsRepo(db *sqlx.DB) alarms.Repository { + return &repository{db: db} +} + +func (r *repository) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + query := `INSERT INTO alarms (id, rule_id, domain_id, channel_id, client_id, subtopic, measurement, value, unit, threshold, cause, status, severity, assignee_id, metadata, created_at) + VALUES (:id, :rule_id, :domain_id, :channel_id, :client_id, :subtopic, :measurement, :value, :unit, :threshold, :cause, :status, :severity, :assignee_id, :metadata, :created_at) + RETURNING id, rule_id, domain_id, channel_id, client_id, subtopic, measurement, value, unit, threshold, cause, status, severity, assignee_id, metadata, created_at;` + dba, err := toDBAlarm(alarm) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrCreateEntity, err) + } + row, err := r.db.NamedQueryContext(ctx, query, dba) + if err != nil { + return alarms.Alarm{}, postgres.HandleError(repoerr.ErrCreateEntity, err) + } + defer row.Close() + + if !row.Next() { + return alarms.Alarm{}, repoerr.ErrNotFound + } + + dba = dbAlarm{} + if err := row.StructScan(&dba); err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return toAlarm(dba) +} + +func (r *repository) UpdateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + var query []string + var upq string + if alarm.Status != 0 { + query = append(query, "status = :status,") + } + if alarm.AssigneeID != "" { + query = append(query, "assignee_id = :assignee_id,") + } + if !alarm.AssignedAt.IsZero() { + query = append(query, "assigned_at = :assigned_at,") + } + if alarm.AssignedBy != "" { + query = append(query, "assigned_by = :assigned_by,") + } + if alarm.AcknowledgedBy != "" { + query = append(query, "acknowledged_by = :acknowledged_by,") + } + if !alarm.AcknowledgedAt.IsZero() { + query = append(query, "acknowledged_at = :acknowledged_at,") + } + if alarm.ResolvedBy != "" { + query = append(query, "resolved_by = :resolved_by,") + } + if !alarm.ResolvedAt.IsZero() { + query = append(query, "resolved_at = :resolved_at,") + } + if alarm.Metadata != nil { + query = append(query, "metadata = :metadata,") + } + if len(query) > 0 { + upq = strings.Join(query, " ") + } + + q := fmt.Sprintf(`UPDATE alarms SET %s updated_by = :updated_by, updated_at = :updated_at WHERE id = :id + RETURNING id, rule_id, measurement, value, unit, cause, status, domain_id, assignee_id, metadata, created_at, updated_by, updated_at, resolved_by, resolved_at;`, upq) + + dba, err := toDBAlarm(alarm) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + row, err := r.db.NamedQueryContext(ctx, q, dba) + if err != nil { + return alarms.Alarm{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + + if !row.Next() { + return alarms.Alarm{}, repoerr.ErrNotFound + } + + dba = dbAlarm{} + if err := row.StructScan(&dba); err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return toAlarm(dba) +} + +func (r *repository) ViewAlarm(ctx context.Context, alarmID, domainID string) (alarms.Alarm, error) { + query := `SELECT * FROM alarms WHERE id = :id AND domain_id = :domain_id;` + row, err := r.db.NamedQueryContext(ctx, query, map[string]interface{}{ + "id": alarmID, "domain_id": domainID, + }) + if err != nil { + return alarms.Alarm{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + defer row.Close() + + if !row.Next() { + return alarms.Alarm{}, repoerr.ErrNotFound + } + + dba := dbAlarm{} + if err := row.StructScan(&dba); err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + alarm, err := toAlarm(dba) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return alarm, nil +} + +func (r *repository) ListAlarms(ctx context.Context, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + query, err := pageQuery(pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + q := fmt.Sprintf(`SELECT * FROM alarms %s ORDER BY created_at DESC LIMIT :limit OFFSET :offset;`, query) + rows, err := r.db.NamedQueryContext(ctx, q, pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + var items []alarms.Alarm + for rows.Next() { + dba := dbAlarm{} + if err := rows.StructScan(&dba); err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + a, err := toAlarm(dba) + if err != nil { + return alarms.AlarmsPage{}, err + } + + items = append(items, a) + } + + q = fmt.Sprintf(`SELECT COUNT(*) FROM alarms %s;`, query) + total, err := postgres.Total(ctx, r.db, q, pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return alarms.AlarmsPage{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + Alarms: items, + }, nil +} + +func (r *repository) DeleteAlarm(ctx context.Context, id string) error { + query := `DELETE FROM alarms WHERE id = :id;` + result, err := r.db.NamedExecContext(ctx, query, map[string]interface{}{"id": id}) + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + if rowsAffected == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +type dbAlarm struct { + ID string `db:"id"` + RuleID string `db:"rule_id"` + DomainID string `db:"domain_id"` + ChannelID string `db:"channel_id"` + ClientID string `db:"client_id"` + Subtopic string `db:"subtopic"` + Measurement string `db:"measurement"` + Value string `db:"value"` + Unit string `db:"unit"` + Cause string `db:"cause"` + Threshold string `db:"threshold"` + Status alarms.Status `db:"status"` + Severity uint8 `db:"severity"` + AssigneeID string `db:"assignee_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt sql.NullTime `db:"updated_at,omitempty"` + UpdatedBy *string `db:"updated_by,omitempty"` + AssignedAt sql.NullTime `db:"assigned_at,omitempty"` + AssignedBy *string `db:"assigned_by,omitempty"` + AcknowledgedAt sql.NullTime `db:"acknowledged_at,omitempty"` + AcknowledgedBy *string `db:"acknowledged_by,omitempty"` + ResolvedAt sql.NullTime `db:"resolved_at,omitempty"` + ResolvedBy *string `db:"resolved_by,omitempty"` + Metadata []byte `db:"metadata,omitempty"` +} + +func toDBAlarm(a alarms.Alarm) (dbAlarm, error) { + if a.CreatedAt.IsZero() { + a.CreatedAt = time.Now() + } + var updatedBy *string + if a.UpdatedBy != "" { + updatedBy = &a.UpdatedBy + } + var updatedAt sql.NullTime + if a.UpdatedAt != (time.Time{}) { + updatedAt = sql.NullTime{Time: a.UpdatedAt, Valid: true} + } + + var acknowledgedBy *string + if a.AcknowledgedBy != "" { + acknowledgedBy = &a.AcknowledgedBy + } + var acknowledgedAt sql.NullTime + if a.AcknowledgedAt != (time.Time{}) { + acknowledgedAt = sql.NullTime{Time: a.AcknowledgedAt, Valid: true} + } + + var resolvedBy *string + if a.ResolvedBy != "" { + resolvedBy = &a.ResolvedBy + } + var resolvedAt sql.NullTime + if a.ResolvedAt != (time.Time{}) { + resolvedAt = sql.NullTime{Time: a.ResolvedAt, Valid: true} + } + + var assignedBy *string + if a.AssignedBy != "" { + assignedBy = &a.AssignedBy + } + var assignedAt sql.NullTime + if a.AssignedAt != (time.Time{}) { + assignedAt = sql.NullTime{Time: a.AssignedAt, Valid: true} + } + + metadata := []byte("{}") + if len(a.Metadata) > 0 { + b, err := json.Marshal(a.Metadata) + if err != nil { + return dbAlarm{}, errors.Wrap(repoerr.ErrMalformedEntity, err) + } + metadata = b + } + + return dbAlarm{ + ID: a.ID, + RuleID: a.RuleID, + DomainID: a.DomainID, + ChannelID: a.ChannelID, + ClientID: a.ClientID, + Subtopic: a.Subtopic, + Measurement: a.Measurement, + Value: a.Value, + Unit: a.Unit, + Cause: a.Cause, + Threshold: a.Threshold, + Status: a.Status, + Severity: a.Severity, + AssigneeID: a.AssigneeID, + CreatedAt: a.CreatedAt, + UpdatedAt: updatedAt, + UpdatedBy: updatedBy, + AssignedAt: assignedAt, + AssignedBy: assignedBy, + AcknowledgedAt: acknowledgedAt, + AcknowledgedBy: acknowledgedBy, + ResolvedAt: resolvedAt, + ResolvedBy: resolvedBy, + Metadata: metadata, + }, nil +} + +func toAlarm(dbr dbAlarm) (alarms.Alarm, error) { + var updatedBy string + if dbr.UpdatedBy != nil { + updatedBy = *dbr.UpdatedBy + } + var updatedAt time.Time + if dbr.UpdatedAt.Valid { + updatedAt = dbr.UpdatedAt.Time + } + + var assignedBy string + if dbr.AssignedBy != nil { + assignedBy = *dbr.AssignedBy + } + var assignedAt time.Time + if dbr.AssignedAt.Valid { + assignedAt = dbr.AssignedAt.Time + } + + var acknowledgedBy string + if dbr.AcknowledgedBy != nil { + acknowledgedBy = *dbr.AcknowledgedBy + } + var acknowledgedAt time.Time + if dbr.AcknowledgedAt.Valid { + acknowledgedAt = dbr.AcknowledgedAt.Time + } + + var resolvedBy string + if dbr.ResolvedBy != nil { + resolvedBy = *dbr.ResolvedBy + } + var resolvedAt time.Time + if dbr.ResolvedAt.Valid { + resolvedAt = dbr.ResolvedAt.Time + } + + var metadata map[string]interface{} + if len(dbr.Metadata) > 0 { + err := json.Unmarshal(dbr.Metadata, &metadata) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrMalformedEntity, err) + } + } + + return alarms.Alarm{ + ID: dbr.ID, + RuleID: dbr.RuleID, + DomainID: dbr.DomainID, + ChannelID: dbr.ChannelID, + ClientID: dbr.ClientID, + Subtopic: dbr.Subtopic, + Measurement: dbr.Measurement, + Value: dbr.Value, + Unit: dbr.Unit, + Threshold: dbr.Threshold, + Cause: dbr.Cause, + Status: dbr.Status, + Severity: dbr.Severity, + AssigneeID: dbr.AssigneeID, + CreatedAt: dbr.CreatedAt, + UpdatedAt: updatedAt, + UpdatedBy: updatedBy, + AssignedAt: assignedAt, + AssignedBy: assignedBy, + AcknowledgedAt: acknowledgedAt, + AcknowledgedBy: acknowledgedBy, + ResolvedAt: resolvedAt, + ResolvedBy: resolvedBy, + Metadata: metadata, + }, nil +} + +func pageQuery(pm alarms.PageMetadata) (string, error) { + var query []string + if pm.DomainID != "" { + query = append(query, "domain_id = :domain_id") + } + if pm.ChannelID != "" { + query = append(query, "channel_id = :channel_id") + } + if pm.ClientID != "" { + query = append(query, "client_id = :client_id") + } + if pm.Subtopic != "" { + query = append(query, "subtopic = :subtopic") + } + if pm.RuleID != "" { + query = append(query, "rule_id = :rule_id") + } + if pm.Status != alarms.AllStatus { + query = append(query, "status = :status") + } + if pm.AssigneeID != "" { + query = append(query, "assignee_id = :assignee_id") + } + if pm.Severity != math.MaxUint8 { + query = append(query, "severity = :severity") + } + if pm.UpdatedBy != "" { + query = append(query, "updated_by = :updated_by") + } + if pm.ResolvedBy != "" { + query = append(query, "resolved_by = :resolved_by") + } + if pm.AcknowledgedBy != "" { + query = append(query, "acknowledged_by = :acknowledged_by") + } + if pm.AssignedBy != "" { + query = append(query, "assigned_by = :assigned_by") + } + + var emq string + if len(query) > 0 { + emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) + } + + return emq, nil +} diff --git a/alarms/postgres/alarms_test.go b/alarms/postgres/alarms_test.go new file mode 100644 index 000000000..b435da524 --- /dev/null +++ b/alarms/postgres/alarms_test.go @@ -0,0 +1,474 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + "github.com/absmach/magistrala/alarms" + "github.com/absmach/magistrala/alarms/postgres" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + namegen = namegenerator.NewGenerator() + idProvider = uuid.New() +) + +func TestCreateAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Subtopic: namegen.Generate(), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarm, + err: nil, + }, + { + desc: "duplicate alarm", + alarm: alarm, + err: repoerr.ErrConflict, + }, + { + desc: "missing rule id", + alarm: alarms.Alarm{ + ID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Subtopic: namegen.Generate(), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + + Metadata: map[string]interface{}{ + "key": "value", + }, + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "invalid alarm", + alarm: alarms.Alarm{ + ID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Subtopic: namegen.Generate(), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + + Metadata: map[string]interface{}{ + "key": make(chan int), + }, + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "empty alarm", + alarm: alarms.Alarm{}, + err: repoerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarm, err := repo.CreateAlarm(context.Background(), tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.NotEmpty(t, alarm.ID) + require.Equal(t, tc.alarm.RuleID, alarm.RuleID) + require.Equal(t, tc.alarm.Measurement, alarm.Measurement) + require.Equal(t, tc.alarm.Value, alarm.Value) + require.Equal(t, tc.alarm.Unit, alarm.Unit) + require.Equal(t, tc.alarm.Cause, alarm.Cause) + require.Equal(t, tc.alarm.Status, alarm.Status) + require.Equal(t, tc.alarm.DomainID, alarm.DomainID) + require.Equal(t, tc.alarm.AssigneeID, alarm.AssigneeID) + require.Equal(t, tc.alarm.Metadata, alarm.Metadata) + }) + } +} + +func TestUpdateAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + ID: alarm.ID, + Status: alarms.ActiveStatus, + DomainID: alarm.DomainID, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: alarm.CreatedAt, + UpdatedAt: time.Now().Local(), + UpdatedBy: generateUUID(&testing.T{}), + ResolvedAt: time.Now().Local(), + ResolvedBy: generateUUID(&testing.T{}), + Metadata: map[string]interface{}{ + "key": "value", + }, + }, + err: nil, + }, + { + desc: "non existing alarm", + alarm: alarms.Alarm{ + ID: generateUUID(&testing.T{}), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "invalid alarm", + alarm: alarms.Alarm{ + ID: alarm.ID, + RuleID: generateUUID(&testing.T{}), + Status: 0, + DomainID: generateUUID(&testing.T{}), + AssigneeID: strings.Repeat("a", 40), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + }, + err: repoerr.ErrMalformedEntity, + }, + { + desc: "empty alarm", + alarm: alarms.Alarm{}, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarm, err := repo.UpdateAlarm(context.Background(), tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.NotEmpty(t, alarm.ID) + require.Equal(t, tc.alarm.Status, alarm.Status) + require.Equal(t, tc.alarm.DomainID, alarm.DomainID) + require.Equal(t, tc.alarm.AssigneeID, alarm.AssigneeID) + require.Equal(t, tc.alarm.Metadata, alarm.Metadata) + }) + } +} + +func TestViewAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + domainID string + err error + }{ + { + desc: "valid alarm", + id: alarm.ID, + domainID: alarm.DomainID, + err: nil, + }, + { + desc: "non existing alarm id", + id: generateUUID(&testing.T{}), + domainID: alarm.DomainID, + err: repoerr.ErrNotFound, + }, + { + desc: "non existing domain id", + id: alarm.ID, + domainID: generateUUID(&testing.T{}), + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarm, err := repo.ViewAlarm(context.Background(), tc.id, tc.domainID) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.NotEmpty(t, alarm.ID) + require.Equal(t, tc.id, alarm.ID) + }) + } +} + +func TestListAlarms(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + repo := postgres.NewAlarmsRepo(db) + items := make([]alarms.Alarm, 1000) + for i := range 1000 { + items[i] = alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), items[i]) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + items[i].ID = alarm.ID + } + + cases := []struct { + desc string + pm alarms.PageMetadata + response []alarms.Alarm + err error + }{ + { + desc: "valid page", + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 10, + }, + response: items[:10], + err: nil, + }, + { + desc: "offset and limit", + pm: alarms.PageMetadata{ + Offset: 10, + Limit: 50, + }, + response: items[10:60], + err: nil, + }, + { + desc: "empty page", + pm: alarms.PageMetadata{}, + response: []alarms.Alarm{}, + err: nil, + }, + { + desc: "invalid page", + pm: alarms.PageMetadata{ + Offset: 1000, + Limit: 10, + }, + response: []alarms.Alarm{}, + err: nil, + }, + { + desc: "invalid assignee id", + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 10, + AssigneeID: generateUUID(&testing.T{}), + }, + response: []alarms.Alarm{}, + err: nil, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarms, err := repo.ListAlarms(context.Background(), tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, len(tc.response), len(alarms.Alarms)) + }) + } +} + +func TestDeleteAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "valid alarm", + id: alarm.ID, + err: nil, + }, + { + desc: "non existing alarm", + id: generateUUID(&testing.T{}), + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.DeleteAlarm(context.Background(), tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + }) + } +} + +func generateUUID(t *testing.T) string { + ulid, err := idProvider.ID() + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + return ulid +} diff --git a/alarms/postgres/init.go b/alarms/postgres/init.go new file mode 100644 index 000000000..6ccb3ab44 --- /dev/null +++ b/alarms/postgres/init.go @@ -0,0 +1,52 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access + migrate "github.com/rubenv/sql-migrate" +) + +// Migration of Users service. +func Migration() *migrate.MemoryMigrationSource { + return &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "alarms_01", + // VARCHAR(36) for columns with IDs as UUIDS have a maximum of 36 characters + Up: []string{ + `CREATE TABLE IF NOT EXISTS alarms ( + id VARCHAR(36) PRIMARY KEY, + rule_id VARCHAR(36) NOT NULL CHECK (length(rule_id) > 0), + domain_id VARCHAR(36) NOT NULL, + channel_id VARCHAR(36) NOT NULL, + client_id VARCHAR(36) NOT NULL, + subtopic TEXT NOT NULL, + measurement TEXT NOT NULL, + value TEXT NOT NULL, + unit TEXT NOT NULL, + threshold TEXT NOT NULL, + cause TEXT NOT NULL, + status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0), + severity SMALLINT NOT NULL DEFAULT 0 CHECK (severity >= 0), + assignee_id VARCHAR(36), + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NULL, + updated_by VARCHAR(36) NULL, + assigned_at TIMESTAMPTZ NULL, + assigned_by VARCHAR(36) NULL, + acknowledged_at TIMESTAMPTZ NULL, + acknowledged_by VARCHAR(36) NULL, + resolved_at TIMESTAMPTZ NULL, + resolved_by VARCHAR(36) NULL, + metadata JSONB + );`, + }, + Down: []string{ + `DROP TABLE IF EXISTS alarms`, + }, + }, + }, + } +} diff --git a/alarms/postgres/setup_test.go b/alarms/postgres/setup_test.go new file mode 100644 index 000000000..3452d75ef --- /dev/null +++ b/alarms/postgres/setup_test.go @@ -0,0 +1,93 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "database/sql" + "fmt" + "log" + "os" + "testing" + "time" + + apostgres "github.com/absmach/magistrala/alarms/postgres" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" + dockertest "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "go.opentelemetry.io/otel" +) + +var ( + db *sqlx.DB + database postgres.Database + tracer = otel.Tracer("repo_tests") +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + // exponential backoff-retry, because the application in the container might not be ready to accept connections yet + pool.MaxWait = 120 * time.Second + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err := sql.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := postgres.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + if db, err = postgres.Setup(dbConfig, *apostgres.Migration()); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + database = postgres.NewDatabase(db, dbConfig, tracer) + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/alarms/service.go b/alarms/service.go new file mode 100644 index 000000000..c4bf28304 --- /dev/null +++ b/alarms/service.go @@ -0,0 +1,84 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms + +import ( + "context" + "time" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/pkg/authn" +) + +type service struct { + idp supermq.IDProvider + repo Repository +} + +var _ Service = (*service)(nil) + +func NewService(idp supermq.IDProvider, repo Repository) Service { + return &service{ + idp: idp, + repo: repo, + } +} + +func (s *service) CreateAlarm(ctx context.Context, alarm Alarm) error { + id, err := s.idp.ID() + if err != nil { + return err + } + alarm.ID = id + if alarm.CreatedAt.IsZero() { + alarm.CreatedAt = time.Now() + } + + if err := alarm.Validate(); err != nil { + return err + } + + pm := PageMetadata{ + Limit: 1, + Offset: 0, + DomainID: alarm.DomainID, + ChannelID: alarm.ChannelID, + ClientID: alarm.ClientID, + Subtopic: alarm.Subtopic, + RuleID: alarm.RuleID, + Severity: alarm.Severity, + Status: alarm.Status, + } + lastAlarms, err := s.repo.ListAlarms(ctx, pm) + if err != nil { + return err + } + + if len(lastAlarms.Alarms) > 0 { + return nil + } + + _, err = s.repo.CreateAlarm(ctx, alarm) + + return err +} + +func (s *service) ViewAlarm(ctx context.Context, session authn.Session, alarmID string) (Alarm, error) { + return s.repo.ViewAlarm(ctx, alarmID, session.DomainID) +} + +func (s *service) ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error) { + return s.repo.ListAlarms(ctx, pm) +} + +func (s *service) DeleteAlarm(ctx context.Context, session authn.Session, alarmID string) error { + return s.repo.DeleteAlarm(ctx, alarmID) +} + +func (s *service) UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error) { + alarm.UpdatedAt = time.Now() + alarm.UpdatedBy = session.UserID + + return s.repo.UpdateAlarm(ctx, alarm) +} diff --git a/alarms/service_test.go b/alarms/service_test.go new file mode 100644 index 000000000..aba72d339 --- /dev/null +++ b/alarms/service_test.go @@ -0,0 +1,253 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms_test + +import ( + "context" + "fmt" + "testing" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/magistrala/alarms/mocks" + "github.com/absmach/magistrala/pkg/errors" + "github.com/absmach/supermq/pkg/authn" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var idp = uuid.New() + +func TestCreateAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + RuleID: "rule-id", + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: nil, + }, + { + desc: "missing rule_id", + alarm: alarms.Alarm{ + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("rule_id is required"), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("CreateAlarm", context.Background(), mock.Anything).Return(tc.alarm, tc.err) + repoCall1 := repo.On("ListAlarms", context.Background(), alarms.PageMetadata{Offset: 0, Limit: 1}).Return(alarms.AlarmsPage{}, tc.err) + err := svc.CreateAlarm(context.Background(), tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + repoCall1.Unset() + }) + } +} + +func TestViewAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + id string + domainID string + err error + }{ + { + desc: "valid alarm", + id: "alarm-id", + domainID: "domain-id", + err: nil, + }, + { + desc: "non existing alarm id", + id: "alarm-id", + domainID: "domain-id", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.domainID} + repoCall := repo.On("ViewAlarm", context.Background(), tc.id, tc.domainID).Return(alarms.Alarm{}, tc.err) + _, err := svc.ViewAlarm(context.Background(), s, tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} + +func TestUpdateAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + RuleID: "rule-id", + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: nil, + }, + { + desc: "non existing alarm", + alarm: alarms.Alarm{ + RuleID: "rule-id", + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.alarm.DomainID} + repoCall := repo.On("UpdateAlarm", context.Background(), mock.Anything).Return(tc.alarm, tc.err) + _, err := svc.UpdateAlarm(context.Background(), s, tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} + +func TestListAlarms(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + pm alarms.PageMetadata + page alarms.AlarmsPage + err error + }{ + { + desc: "valid page", + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 10, + }, + page: alarms.AlarmsPage{ + Offset: 0, + Limit: 10, + Total: 10, + Alarms: []alarms.Alarm{}, + }, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.pm.DomainID} + repoCall := repo.On("ListAlarms", context.Background(), tc.pm).Return(tc.page, tc.err) + _, err := svc.ListAlarms(context.Background(), s, tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} + +func TestDeleteAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "valid alarm", + id: "alarm-id", + err: nil, + }, + { + desc: "non existing alarm", + id: "alarm-id", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.id} + repoCall := repo.On("DeleteAlarm", context.Background(), tc.id).Return(tc.err) + err := svc.DeleteAlarm(context.Background(), s, tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} diff --git a/alarms/status.go b/alarms/status.go new file mode 100644 index 000000000..90a6f7169 --- /dev/null +++ b/alarms/status.go @@ -0,0 +1,70 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms + +import ( + "encoding/json" + "strings" + + svcerr "github.com/absmach/supermq/pkg/errors/service" +) + +type Status uint8 + +const ( + ActiveStatus Status = iota + ClearedStatus + + // AllStatus is used for querying purposes to list alarms irrespective + // of their status. It is never stored in the database as the actual + // Alarm status and should always be the largest value in this enumeration. + AllStatus +) + +const ( + Active = "active" + Cleared = "cleared" + Unknown = "unknown" + All = "all" +) + +// String converts alarm status to string literal. +func (s Status) String() string { + switch s { + case ActiveStatus: + return Active + case ClearedStatus: + return Cleared + default: + return Unknown + } +} + +// ToStatus converts string value to a valid Alarm status. +func ToStatus(status string) (Status, error) { + switch status { + case Active: + return ActiveStatus, nil + case Cleared: + return ClearedStatus, nil + case All: + return AllStatus, nil + default: + return Status(0), svcerr.ErrInvalidStatus + } +} + +// Custom Marshaller for Alarm. +func (s Status) MarshalJSON() ([]byte, error) { + return json.Marshal(s.String()) +} + +// Custom Unmarshaler for Alarm. +func (s *Status) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToStatus(str) + *s = val + + return err +} diff --git a/bootstrap/events/producer/streams.go b/bootstrap/events/producer/streams.go index 750550854..5ea708bcf 100644 --- a/bootstrap/events/producer/streams.go +++ b/bootstrap/events/producer/streams.go @@ -14,22 +14,21 @@ import ( var _ bootstrap.Service = (*eventStore)(nil) const ( - streamPrefix = ".bootstrap" - addStream = streamPrefix + "add" - viewStream = streamPrefix + "view" - updateStream = streamPrefix + "update" - listStream = streamPrefix + "list" - bootstrapStream = streamPrefix + "bootstrap" - removeStream = streamPrefix + "remove" - updateCertStream = streamPrefix + "update_cert" - updateConnectionsStream = streamPrefix + "update_connections" - changeStateStream = streamPrefix + "change_state" - - connectClientHandlerStream = streamPrefix + "connect_client_handler" - disconnectClientHandlerStream = streamPrefix + "disconnect_client_handler" - removeConfigHandlerStream = streamPrefix + "remove_config_handler" - removeChannelHandlerStream = streamPrefix + "remove_channel_handler" - updateChannelHandlerStream = streamPrefix + "update_channel_handler" + magistralaPrefix = "magistrala." + createStream = magistralaPrefix + configCreate + viewStream = magistralaPrefix + configView + listStream = magistralaPrefix + configList + updateStream = magistralaPrefix + configUpdate + removeStream = magistralaPrefix + configRemove + updateCertStream = magistralaPrefix + certUpdate + updateConnectionsStream = magistralaPrefix + clientUpdateConnections + removeHandlerStream = magistralaPrefix + configHandlerRemove + bootstrapStream = magistralaPrefix + clientBootstrap + stateChangeStream = magistralaPrefix + clientStateChange + connectStream = magistralaPrefix + clientConnect + disconnectStream = magistralaPrefix + clientDisconnect + updateHandlerStream = magistralaPrefix + channelUpdateHandler + removeChannelHandlerStream = magistralaPrefix + channelHandlerRemove ) type eventStore struct { @@ -56,7 +55,7 @@ func (es *eventStore) Add(ctx context.Context, session smqauthn.Session, token s saved, configCreate, } - if err := es.Publish(ctx, addStream, ev); err != nil { + if err := es.Publish(ctx, createStream, ev); err != nil { return saved, err } @@ -72,7 +71,7 @@ func (es *eventStore) View(ctx context.Context, session smqauthn.Session, id str cfg, configView, } - if err := es.Publish(ctx, viewStream, ev); err != nil { + if err := es.Publish(ctx, configView, ev); err != nil { return cfg, err } @@ -88,7 +87,7 @@ func (es *eventStore) Update(ctx context.Context, session smqauthn.Session, cfg cfg, configUpdate, } - return es.Publish(ctx, updateStream, ev) + return es.Publish(ctx, configUpdate, ev) } func (es eventStore) UpdateCert(ctx context.Context, session smqauthn.Session, clientID, clientCert, clientKey, caCert string) (bootstrap.Config, error) { @@ -186,7 +185,7 @@ func (es *eventStore) ChangeState(ctx context.Context, session smqauthn.Session, state: state, } - return es.Publish(ctx, changeStateStream, ev) + return es.Publish(ctx, stateChangeStream, ev) } func (es *eventStore) RemoveConfigHandler(ctx context.Context, id string) error { @@ -199,7 +198,7 @@ func (es *eventStore) RemoveConfigHandler(ctx context.Context, id string) error operation: configHandlerRemove, } - return es.Publish(ctx, removeConfigHandlerStream, ev) + return es.Publish(ctx, removeHandlerStream, ev) } func (es *eventStore) RemoveChannelHandler(ctx context.Context, id string) error { @@ -224,7 +223,7 @@ func (es *eventStore) UpdateChannelHandler(ctx context.Context, channel bootstra channel, } - return es.Publish(ctx, updateChannelHandlerStream, ev) + return es.Publish(ctx, updateStream, ev) } func (es *eventStore) ConnectClientHandler(ctx context.Context, channelID, clientID string) error { @@ -237,7 +236,7 @@ func (es *eventStore) ConnectClientHandler(ctx context.Context, channelID, clien channelID: channelID, } - return es.Publish(ctx, connectClientHandlerStream, ev) + return es.Publish(ctx, connectStream, ev) } func (es *eventStore) DisconnectClientHandler(ctx context.Context, channelID, clientID string) error { @@ -250,5 +249,5 @@ func (es *eventStore) DisconnectClientHandler(ctx context.Context, channelID, cl channelID: channelID, } - return es.Publish(ctx, disconnectClientHandlerStream, ev) + return es.Publish(ctx, disconnectStream, ev) } diff --git a/bootstrap/events/producer/streams_test.go b/bootstrap/events/producer/streams_test.go index cf9d43d0d..fe83cea10 100644 --- a/bootstrap/events/producer/streams_test.go +++ b/bootstrap/events/producer/streams_test.go @@ -192,7 +192,7 @@ func TestAdd(t *testing.T) { lastID := "0" for _, tc := range cases { tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} - sdkCall := tv.sdk.On("Client", tc.config.ClientID, tc.domainID, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, errors.NewSDKError(tc.clientErr)) + sdkCall := tv.sdk.On("Client", mock.Anything, tc.config.ClientID, tc.domainID, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, errors.NewSDKError(tc.clientErr)) repoCall := tv.boot.On("ListExisting", context.Background(), domainID, mock.Anything).Return(tc.config.Channels, tc.listErr) repoCall1 := tv.boot.On("Save", context.Background(), mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) @@ -475,7 +475,7 @@ func TestUpdateConnections(t *testing.T) { lastID := "0" for _, tc := range cases { tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} - sdkCall := tv.sdk.On("Channel", mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) + sdkCall := tv.sdk.On("Channel", mock.Anything, mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.configID).Return(config, tc.retrieveErr) repoCall1 := tv.boot.On("ListExisting", context.Background(), domainID, mock.Anything, mock.Anything).Return(config.Channels, tc.listErr) repoCall2 := tv.boot.On("UpdateConnections", context.Background(), tc.domainID, tc.configID, mock.Anything, tc.connections).Return(tc.updateErr) @@ -1054,7 +1054,7 @@ func TestChangeState(t *testing.T) { for _, tc := range cases { tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(config, tc.retrieveErr) - sdkCall1 := tv.sdk.On("ConnectClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.NewSDKError(tc.connectErr)) + sdkCall1 := tv.sdk.On("ConnectClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.NewSDKError(tc.connectErr)) repoCall1 := tv.boot.On("ChangeState", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.stateErr) err := tv.svc.ChangeState(context.Background(), tc.session, tc.token, tc.id, tc.state) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) diff --git a/bootstrap/service.go b/bootstrap/service.go index 9e27d70f2..ab3709db9 100644 --- a/bootstrap/service.go +++ b/bootstrap/service.go @@ -144,6 +144,7 @@ func (bs bootstrapService) Add(ctx context.Context, session smqauthn.Session, to if err != nil { return Config{}, errors.Wrap(errCheckChannels, err) } + cfg.Channels, err = bs.connectionChannels(ctx, toConnect, bs.toIDList(existing), session.DomainID, token) if err != nil { return Config{}, errors.Wrap(errConnectionChannels, err) diff --git a/bootstrap/service_test.go b/bootstrap/service_test.go index 9ba5b5984..48b78ed4b 100644 --- a/bootstrap/service_test.go +++ b/bootstrap/service_test.go @@ -152,9 +152,9 @@ func TestAdd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} - repoCall := sdk.On("Client", tc.config.ClientID, mock.Anything, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, tc.clientErr) - repoCall1 := sdk.On("CreateClient", mock.Anything, tc.domainID, tc.token).Return(mgsdk.Client{}, tc.createClientErr) - repoCall2 := sdk.On("DeleteClient", tc.config.ClientID, tc.domainID, tc.token).Return(tc.deleteClientErr) + repoCall := sdk.On("Client", mock.Anything, tc.config.ClientID, mock.Anything, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, tc.clientErr) + repoCall1 := sdk.On("CreateClient", mock.Anything, mock.Anything, tc.domainID, tc.token).Return(mgsdk.Client{}, tc.createClientErr) + repoCall2 := sdk.On("DeleteClient", mock.Anything, tc.config.ClientID, tc.domainID, tc.token).Return(tc.deleteClientErr) repoCall3 := boot.On("ListExisting", context.Background(), tc.domainID, mock.Anything).Return(tc.config.Channels, tc.listExistingErr) repoCall4 := boot.On("Save", context.Background(), mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) _, err := svc.Add(context.Background(), tc.session, tc.token, tc.config) @@ -434,7 +434,7 @@ func TestUpdateConnections(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} - sdkCall := sdk.On("Channel", mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) + sdkCall := sdk.On("Channel", mock.Anything, mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) repoCall := boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(c, tc.retrieveErr) repoCall1 := boot.On("ListExisting", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(c.Channels, tc.listErr) repoCall2 := boot.On("UpdateConnections", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.updateErr) @@ -939,7 +939,7 @@ func TestChangeState(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} repoCall := boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(c, tc.retrieveErr) - sdkCall := sdk.On("ConnectClients", mock.Anything, mock.Anything, []string{"Publish", "Subscribe"}, mock.Anything, tc.token).Return(tc.connectErr) + sdkCall := sdk.On("ConnectClients", mock.Anything, mock.Anything, mock.Anything, []string{"Publish", "Subscribe"}, mock.Anything, tc.token).Return(tc.connectErr) repoCall1 := boot.On("ChangeState", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.stateErr) err := svc.ChangeState(context.Background(), tc.session, tc.token, tc.id, tc.state) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) diff --git a/cli/bootstrap.go b/cli/bootstrap.go index 69daac7a6..1acd80cd8 100644 --- a/cli/bootstrap.go +++ b/cli/bootstrap.go @@ -27,7 +27,7 @@ var cmdBootstrap = []cobra.Command{ return } - id, err := sdk.AddBootstrap(cfg, args[1], args[2]) + id, err := sdk.AddBootstrap(cmd.Context(), cfg, args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -54,7 +54,7 @@ var cmdBootstrap = []cobra.Command{ Name: Name, } if args[0] == "all" { - l, err := sdk.Bootstraps(pageMetadata, args[1], args[2]) + l, err := sdk.Bootstraps(cmd.Context(), pageMetadata, args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -63,7 +63,7 @@ var cmdBootstrap = []cobra.Command{ return } - c, err := sdk.ViewBootstrap(args[0], args[1], args[2]) + c, err := sdk.ViewBootstrap(cmd.Context(), args[0], args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -92,7 +92,7 @@ var cmdBootstrap = []cobra.Command{ return } - if err := sdk.UpdateBootstrap(cfg, args[1], args[2]); err != nil { + if err := sdk.UpdateBootstrap(cmd.Context(), cfg, args[1], args[2]); err != nil { logErrorCmd(*cmd, err) return } @@ -106,7 +106,7 @@ var cmdBootstrap = []cobra.Command{ logErrorCmd(*cmd, err) return } - if err := sdk.UpdateBootstrapConnection(args[1], ids, args[3], args[4]); err != nil { + if err := sdk.UpdateBootstrapConnection(cmd.Context(), args[1], ids, args[3], args[4]); err != nil { logErrorCmd(*cmd, err) return } @@ -115,7 +115,7 @@ var cmdBootstrap = []cobra.Command{ return } if args[0] == "certs" { - cfg, err := sdk.UpdateBootstrapCerts(args[0], args[1], args[2], args[3], args[4], args[5]) + cfg, err := sdk.UpdateBootstrapCerts(cmd.Context(), args[0], args[1], args[2], args[3], args[4], args[5]) if err != nil { logErrorCmd(*cmd, err) return @@ -137,7 +137,7 @@ var cmdBootstrap = []cobra.Command{ return } - if err := sdk.RemoveBootstrap(args[0], args[1], args[2]); err != nil { + if err := sdk.RemoveBootstrap(cmd.Context(), args[0], args[1], args[2]); err != nil { logErrorCmd(*cmd, err) return } @@ -156,7 +156,7 @@ var cmdBootstrap = []cobra.Command{ return } if args[0] == "secure" { - c, err := sdk.BootstrapSecure(args[1], args[2], args[3]) + c, err := sdk.BootstrapSecure(cmd.Context(), args[1], args[2], args[3]) if err != nil { logErrorCmd(*cmd, err) return @@ -165,7 +165,7 @@ var cmdBootstrap = []cobra.Command{ logJSONCmd(*cmd, c) return } - c, err := sdk.Bootstrap(args[0], args[1]) + c, err := sdk.Bootstrap(cmd.Context(), args[0], args[1]) if err != nil { logErrorCmd(*cmd, err) return @@ -190,7 +190,7 @@ var cmdBootstrap = []cobra.Command{ return } - if err := sdk.Whitelist(cfg.ClientID, cfg.State, args[1], args[2]); err != nil { + if err := sdk.Whitelist(cmd.Context(), cfg.ClientID, cfg.State, args[1], args[2]); err != nil { logErrorCmd(*cmd, err) return } diff --git a/cli/bootstrap_test.go b/cli/bootstrap_test.go index f549c9b01..3b28677cd 100644 --- a/cli/bootstrap_test.go +++ b/cli/bootstrap_test.go @@ -102,7 +102,7 @@ func TestCreateBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("AddBootstrap", mock.Anything, mock.Anything, mock.Anything).Return(tc.id, tc.sdkErr) + sdkCall := sdkMock.On("AddBootstrap", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.id, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{createCmd}, tc.args...)...) switch tc.logType { @@ -199,8 +199,8 @@ func TestGetBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("ViewBootstrap", tc.args[0], tc.args[1], tc.args[2]).Return(tc.boot, tc.sdkErr) - sdkCall1 := sdkMock.On("Bootstraps", mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr) + sdkCall := sdkMock.On("ViewBootstrap", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.boot, tc.sdkErr) + sdkCall1 := sdkMock.On("Bootstraps", mock.Anything, mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...) @@ -284,7 +284,7 @@ func TestRemoveBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("RemoveBootstrap", tc.args[0], tc.args[1], tc.args[2]).Return(tc.sdkErr) + sdkCall := sdkMock.On("RemoveBootstrap", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{rmCmd}, tc.args...)...) switch tc.logType { @@ -443,9 +443,9 @@ func TestUpdateBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { var boot mgsdk.BootstrapConfig - sdkCall := sdkMock.On("UpdateBootstrap", mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) - sdkCall1 := sdkMock.On("UpdateBootstrapConnection", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) - sdkCall2 := sdkMock.On("UpdateBootstrapCerts", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) + sdkCall := sdkMock.On("UpdateBootstrap", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) + sdkCall1 := sdkMock.On("UpdateBootstrapConnection", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) + sdkCall2 := sdkMock.On("UpdateBootstrapCerts", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{updCmd}, tc.args...)...) switch tc.logType { @@ -527,7 +527,7 @@ func TestWhitelistConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("Whitelist", mock.Anything, mock.Anything, tc.args[1], tc.args[2]).Return(tc.sdkErr) + sdkCall := sdkMock.On("Whitelist", mock.Anything, mock.Anything, mock.Anything, tc.args[1], tc.args[2]).Return(tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{whitelistCmd}, tc.args...)...) switch tc.logType { case okLog: @@ -613,8 +613,8 @@ func TestBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("BootstrapSecure", mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) - sdkCall1 := sdkMock.On("Bootstrap", mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) + sdkCall := sdkMock.On("BootstrapSecure", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) + sdkCall1 := sdkMock.On("Bootstrap", mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{bootStrapCmd}, tc.args...)...) switch tc.logType { case entityLog: diff --git a/cli/consumers.go b/cli/consumers.go index 64bcd6009..99ef471df 100644 --- a/cli/consumers.go +++ b/cli/consumers.go @@ -19,7 +19,7 @@ var cmdSubscription = []cobra.Command{ return } - id, err := sdk.CreateSubscription(args[0], args[1], args[2]) + id, err := sdk.CreateSubscription(cmd.Context(), args[0], args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -46,7 +46,7 @@ var cmdSubscription = []cobra.Command{ Contact: Contact, } if args[0] == "all" { - sub, err := sdk.ListSubscriptions(pageMetadata, args[1]) + sub, err := sdk.ListSubscriptions(cmd.Context(), pageMetadata, args[1]) if err != nil { logErrorCmd(*cmd, err) return @@ -55,7 +55,7 @@ var cmdSubscription = []cobra.Command{ return } - c, err := sdk.ViewSubscription(args[0], args[1]) + c, err := sdk.ViewSubscription(cmd.Context(), args[0], args[1]) if err != nil { logErrorCmd(*cmd, err) return @@ -74,7 +74,7 @@ var cmdSubscription = []cobra.Command{ return } - if err := sdk.DeleteSubscription(args[0], args[1]); err != nil { + if err := sdk.DeleteSubscription(cmd.Context(), args[0], args[1]); err != nil { logErrorCmd(*cmd, err) return } diff --git a/cli/consumers_test.go b/cli/consumers_test.go index 0cb27696c..09a04f485 100644 --- a/cli/consumers_test.go +++ b/cli/consumers_test.go @@ -81,7 +81,7 @@ func TestCreateSubscriptionCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("CreateSubscription", tc.args[0], tc.args[1], tc.args[2]).Return(tc.id, tc.sdkErr) + sdkCall := sdkMock.On("CreateSubscription", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.id, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{createCmd}, tc.args...)...) switch tc.logType { @@ -168,8 +168,8 @@ func TestGetSubscriptionsCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("ViewSubscription", tc.args[0], tc.args[1]).Return(tc.subscription, tc.sdkErr) - sdkCall1 := sdkMock.On("ListSubscriptions", mock.Anything, tc.args[1]).Return(tc.page, tc.sdkErr) + sdkCall := sdkMock.On("ViewSubscription", mock.Anything, tc.args[0], tc.args[1]).Return(tc.subscription, tc.sdkErr) + sdkCall1 := sdkMock.On("ListSubscriptions", mock.Anything, mock.Anything, tc.args[1]).Return(tc.page, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...) @@ -249,7 +249,7 @@ func TestRemoveSubscriptionCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("DeleteSubscription", tc.args[0], tc.args[1]).Return(tc.sdkErr) + sdkCall := sdkMock.On("DeleteSubscription", mock.Anything, tc.args[0], tc.args[1]).Return(tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{rmCmd}, tc.args...)...) switch tc.logType { diff --git a/cli/provision.go b/cli/provision.go index 5a0e1b121..282474303 100644 --- a/cli/provision.go +++ b/cli/provision.go @@ -4,7 +4,6 @@ package cli import ( - "context" "encoding/csv" "encoding/json" "errors" @@ -54,7 +53,7 @@ var cmdProvision = []cobra.Command{ return } - clients, err = sdk.CreateClients(context.Background(), clients, args[1], args[2]) + clients, err = sdk.CreateClients(cmd.Context(), clients, args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -81,7 +80,7 @@ var cmdProvision = []cobra.Command{ var chs []smqsdk.Channel for _, c := range channels { - c, err = sdk.CreateChannel(context.Background(), c, args[1], args[2]) + c, err = sdk.CreateChannel(cmd.Context(), c, args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -109,7 +108,7 @@ var cmdProvision = []cobra.Command{ return } for _, conn := range connIDs { - if err := sdk.Connect(context.Background(), conn, args[1], args[2]); err != nil { + if err := sdk.Connect(cmd.Context(), conn, args[1], args[2]); err != nil { logErrorCmd(*cmd, err) return } @@ -146,13 +145,13 @@ var cmdProvision = []cobra.Command{ }, Status: smqsdk.EnabledStatus, } - user, err := sdk.CreateUser(context.Background(), user, "") + user, err := sdk.CreateUser(cmd.Context(), user, "") if err != nil { logErrorCmd(*cmd, err) return } - ut, err := sdk.CreateToken(context.Background(), smqsdk.Login{Username: user.Credentials.Username, Password: user.Credentials.Secret}) + ut, err := sdk.CreateToken(cmd.Context(), smqsdk.Login{Username: user.Credentials.Username, Password: user.Credentials.Secret}) if err != nil { logErrorCmd(*cmd, err) return @@ -163,13 +162,13 @@ var cmdProvision = []cobra.Command{ Name: fmt.Sprintf("%s-domain", name), Status: smqsdk.EnabledStatus, } - domain, err = sdk.CreateDomain(context.Background(), domain, ut.AccessToken) + domain, err = sdk.CreateDomain(cmd.Context(), domain, ut.AccessToken) if err != nil { logErrorCmd(*cmd, err) return } - ut, err = sdk.CreateToken(context.Background(), smqsdk.Login{Username: user.Email, Password: user.Credentials.Secret}) + ut, err = sdk.CreateToken(cmd.Context(), smqsdk.Login{Username: user.Email, Password: user.Credentials.Secret}) if err != nil { logErrorCmd(*cmd, err) return @@ -184,7 +183,7 @@ var cmdProvision = []cobra.Command{ clients = append(clients, t) } - clients, err = sdk.CreateClients(context.Background(), clients, domain.ID, ut.AccessToken) + clients, err = sdk.CreateClients(cmd.Context(), clients, domain.ID, ut.AccessToken) if err != nil { logErrorCmd(*cmd, err) return @@ -196,7 +195,7 @@ var cmdProvision = []cobra.Command{ Name: fmt.Sprintf("%s-channel-%d", name, i), Status: smqsdk.EnabledStatus, } - c, err = sdk.CreateChannel(context.Background(), c, domain.ID, ut.AccessToken) + c, err = sdk.CreateChannel(cmd.Context(), c, domain.ID, ut.AccessToken) if err != nil { logErrorCmd(*cmd, err) return @@ -211,7 +210,7 @@ var cmdProvision = []cobra.Command{ ClientIDs: []string{clients[0].ID}, Types: []string{PublishType, SubscribeType}, } - if err := sdk.Connect(context.Background(), conIDs, domain.ID, ut.AccessToken); err != nil { + if err := sdk.Connect(cmd.Context(), conIDs, domain.ID, ut.AccessToken); err != nil { logErrorCmd(*cmd, err) return } @@ -221,7 +220,7 @@ var cmdProvision = []cobra.Command{ ClientIDs: []string{clients[0].ID}, Types: []string{PublishType, SubscribeType}, } - if err := sdk.Connect(context.Background(), conIDs, domain.ID, ut.AccessToken); err != nil { + if err := sdk.Connect(cmd.Context(), conIDs, domain.ID, ut.AccessToken); err != nil { logErrorCmd(*cmd, err) return } @@ -231,21 +230,21 @@ var cmdProvision = []cobra.Command{ ClientIDs: []string{clients[1].ID}, Types: []string{PublishType, SubscribeType}, } - if err := sdk.Connect(context.Background(), conIDs, domain.ID, ut.AccessToken); err != nil { + if err := sdk.Connect(cmd.Context(), conIDs, domain.ID, ut.AccessToken); err != nil { logErrorCmd(*cmd, err) return } // send message to test connectivity - if err := sdk.SendMessage(context.Background(), domain.ID, channels[0].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { + if err := sdk.SendMessage(cmd.Context(), domain.ID, channels[0].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { logErrorCmd(*cmd, err) return } - if err := sdk.SendMessage(context.Background(), domain.ID, channels[0].ID, clients[1].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { + if err := sdk.SendMessage(cmd.Context(), domain.ID, channels[0].ID, clients[1].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { logErrorCmd(*cmd, err) return } - if err := sdk.SendMessage(context.Background(), domain.ID, channels[1].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { + if err := sdk.SendMessage(cmd.Context(), domain.ID, channels[1].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { logErrorCmd(*cmd, err) return } diff --git a/cmd/alarms/main.go b/cmd/alarms/main.go new file mode 100644 index 000000000..23b0d91b0 --- /dev/null +++ b/cmd/alarms/main.go @@ -0,0 +1,192 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "fmt" + "log" + "net/url" + "os" + + "github.com/absmach/magistrala/alarms" + httpAPI "github.com/absmach/magistrala/alarms/api" + "github.com/absmach/magistrala/alarms/consumer" + "github.com/absmach/magistrala/alarms/consumer/brokers" + "github.com/absmach/magistrala/alarms/middleware" + alarmsRepo "github.com/absmach/magistrala/alarms/postgres" + "github.com/absmach/magistrala/pkg/prometheus" + smqlog "github.com/absmach/supermq/logger" + "github.com/absmach/supermq/pkg/authn/authsvc" + authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" + domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + "github.com/absmach/supermq/pkg/grpcclient" + "github.com/absmach/supermq/pkg/jaeger" + "github.com/absmach/supermq/pkg/messaging" + brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/server" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + "github.com/caarlos0/env/v11" + "golang.org/x/sync/errgroup" +) + +const ( + svcName = "alarms" + envPrefixDB = "MG_ALARMS_DB_" + envPrefixHTTP = "MG_ALARMS_HTTP_" + envPrefixAuth = "SMQ_AUTH_GRPC_" + defDB = "alarms" + defSvcHTTPPort = "8050" + envPrefixDomains = "SMQ_DOMAINS_GRPC_" +) + +type config struct { + LogLevel string `env:"MG_ALARMS_LOG_LEVEL" envDefault:"info"` + BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` + InstanceID string `env:"MG_ALARMS_INSTANCE_ID" envDefault:""` + JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err.Error()) + } + + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + tp, err := jaeger.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err)) + exitCode = 1 + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("error shutting down tracer provider: %v", err)) + } + }() + tracer := tp.Tracer(svcName) + + dbConfig := postgres.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(err.Error()) + } + + db, err := postgres.Setup(dbConfig, *alarmsRepo.Migration()) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer db.Close() + + repo := alarmsRepo.NewAlarmsRepo(db) + + authConfig := grpcclient.Config{} + if err := env.ParseWithOptions(&authConfig, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err)) + exitCode = 1 + return + } + authn, authnClient, err := authsvc.NewAuthentication(ctx, authConfig) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer authnClient.Close() + logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + + domsGrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { + logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + domAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer domainsHandler.Close() + + authz, authzHandler, err := authsvcAuthz.NewAuthorization(ctx, authConfig, domAuthz) + if err != nil { + logger.Error("failed to create authz " + err.Error()) + exitCode = 1 + return + } + defer authzHandler.Close() + + logger.Info("AuthZ successfully connected to auth gRPC server " + authzHandler.Secure()) + + idp := uuid.New() + + svc := alarms.NewService(idp, repo) + + svc = middleware.NewAuthorizationMiddleware(svc, authz) + svc = middleware.NewLoggingMiddleware(logger, svc) + counter, latency := prometheus.MakeMetrics("alarms", "api") + svc = middleware.NewMetricsMiddleware(counter, latency, svc) + svc = middleware.NewTracingMiddleware(tracer, svc) + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + return + } + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpAPI.MakeHandler(svc, logger, idp, cfg.InstanceID, authn), logger) + + pubSub, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err)) + exitCode = 1 + return + } + defer pubSub.Close() + pubSub = brokerstracing.NewPubSub(httpServerConfig, tracer, pubSub) + + consumer := consumer.Newhandler(svc, logger) + + subCfg := messaging.SubscriberConfig{ + ID: svcName, + Topic: brokers.AllTopic, + DeliveryPolicy: messaging.DeliverAllPolicy, + Handler: consumer, + } + if err := pubSub.Subscribe(ctx, subCfg); err != nil { + logger.Error(fmt.Sprintf("failed to subscribe to message broker: %s", err)) + exitCode = 1 + + return + } + + g.Go(func() error { + return hs.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("billing service terminated: %s", err)) + } +} diff --git a/docker/.env b/docker/.env index 2d8e76b0f..adaf6664c 100644 --- a/docker/.env +++ b/docker/.env @@ -119,6 +119,23 @@ MG_EMAIL_FROM_ADDRESS=from@example.com MG_EMAIL_FROM_NAME=Example MG_EMAIL_TEMPLATE=email.tmpl +### Alarms +MG_ALARMS_LOG_LEVEL=debug +MG_ALARMS_HTTP_HOST=alarms +MG_ALARMS_HTTP_PORT=8050 +MG_ALARMS_HTTP_SERVER_CERT= +MG_ALARMS_HTTP_SERVER_KEY= +MG_ALARMS_DB_HOST=alarms-db +MG_ALARMS_DB_PORT=5432 +MG_ALARMS_DB_USER=magistrala +MG_ALARMS_DB_PASS=magistrala +MG_ALARMS_DB_NAME=alarms +MG_ALARMS_DB_SSL_MODE=disable +MG_ALARMS_DB_SSL_CERT= +MG_ALARMS_DB_SSL_KEY= +MG_ALARMS_DB_SSL_ROOT_CERT= +MG_ALARMS_INSTANCE_ID= + ### Certs SMQ_ADDONS_CERTS_PATH_PREFIX=./ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1301957e9..a6abe34e3 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -22,6 +22,7 @@ volumes: magistrala-ui-backend-db-volume: magistrala-re-db-volume: magistrala-auth-redis-volume: + magistrala-alarms-db-volume: services: ui: @@ -214,3 +215,91 @@ services: bind: create_host_path: true - ./re_config.toml:/config.toml + + alarms-db: + image: postgres:16.2-alpine + container_name: magistrala-alarms-db + restart: on-failure + command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" + environment: + POSTGRES_USER: ${MG_ALARMS_DB_USER} + POSTGRES_PASSWORD: ${MG_ALARMS_DB_PASS} + POSTGRES_DB: ${MG_ALARMS_DB_NAME} + ports: + - 6019:5432 + networks: + - magistrala-base-net + volumes: + - magistrala-alarms-db-volume:/var/lib/postgresql/data + + alarms: + image: ghcr.io/absmach/magistrala/alarms:${MG_RELEASE_TAG} + container_name: magistrala-alarms + depends_on: + - alarms-db + restart: on-failure + environment: + MG_ALARMS_LOG_LEVEL: ${MG_ALARMS_LOG_LEVEL} + MG_ALARMS_HTTP_PORT: ${MG_ALARMS_HTTP_PORT} + MG_ALARMS_HTTP_HOST: ${MG_ALARMS_HTTP_HOST} + MG_ALARMS_HTTP_SERVER_CERT: ${MG_ALARMS_HTTP_SERVER_CERT} + MG_ALARMS_HTTP_SERVER_KEY: ${MG_ALARMS_HTTP_SERVER_KEY} + MG_ALARMS_DB_HOST: ${MG_ALARMS_DB_HOST} + MG_ALARMS_DB_PORT: ${MG_ALARMS_DB_PORT} + MG_ALARMS_DB_USER: ${MG_ALARMS_DB_USER} + MG_ALARMS_DB_PASS: ${MG_ALARMS_DB_PASS} + MG_ALARMS_DB_NAME: ${MG_ALARMS_DB_NAME} + MG_ALARMS_DB_SSL_MODE: ${MG_ALARMS_DB_SSL_MODE} + MG_ALARMS_DB_SSL_CERT: ${MG_ALARMS_DB_SSL_CERT} + MG_ALARMS_DB_SSL_KEY: ${MG_ALARMS_DB_SSL_KEY} + MG_ALARMS_DB_SSL_ROOT_CERT: ${MG_ALARMS_DB_SSL_ROOT_CERT} + SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL} + SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} + SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} + SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} + SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} + SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} + SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} + SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_ALARMS_INSTANCE_ID: ${MG_ALARMS_INSTANCE_ID} + ports: + - ${MG_ALARMS_HTTP_PORT}:${MG_ALARMS_HTTP_PORT} + networks: + - magistrala-base-net + volumes: + # Auth gRPC client certificates + - type: bind + source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + target: /auth-grpc-server-ca${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true + - type: bind + source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + target: /domains-grpc-client${SMQ_DOMAINS_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + target: /domains-grpc-client${SMQ_DOMAINS_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + target: /domains-grpc-server-ca${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true diff --git a/go.mod b/go.mod index 3f6edb06d..0b6a9202e 100644 --- a/go.mod +++ b/go.mod @@ -5,17 +5,19 @@ go 1.24.2 require ( github.com/0x6flab/namegenerator v1.4.0 github.com/absmach/callhome v0.14.0 - github.com/absmach/supermq v0.16.1-0.20250411132829-0e571d1905af + github.com/absmach/supermq v0.16.1-0.20250411155830-602025b48bc0 github.com/authzed/authzed-go v1.3.1-0.20250320210445-0cde0d8c71e2 github.com/authzed/grpcutil v0.0.0-20250221190651-1985b19b35b8 github.com/caarlos0/env/v11 v11.3.1 github.com/eclipse/paho.mqtt.golang v1.5.0 + github.com/fatih/color v1.18.0 github.com/fiorix/go-smpp v0.0.0-20210403173735-2894b96e70ba github.com/go-chi/chi/v5 v5.2.1 github.com/go-kit/kit v0.13.0 github.com/gofrs/uuid/v5 v5.3.2 github.com/gookit/color v1.5.4 github.com/gorilla/websocket v1.5.3 + github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f github.com/ivanpirog/coloredcobra v1.0.1 github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 github.com/jackc/pgtype v1.14.4 @@ -41,21 +43,13 @@ require ( moul.io/http2curl v1.0.0 ) -require ( - filippo.io/edwards25519 v1.1.0 // indirect - github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02 // indirect - github.com/go-sql-driver/mysql v1.8.1 // indirect - github.com/go-viper/mapstructure/v2 v2.2.1 // indirect - github.com/lib/pq v1.10.9 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect - github.com/moby/sys/user v0.3.0 // indirect -) - require ( dario.cat/mergo v1.0.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect + github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02 // indirect github.com/absmach/senml v1.0.7 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect @@ -69,7 +63,6 @@ require ( github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect - github.com/fatih/color v1.18.0 github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/fxamacker/cbor/v2 v2.8.0 // indirect @@ -78,12 +71,13 @@ require ( github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect - github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -91,12 +85,15 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jzelinskie/stringz v0.0.3 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/moby/sys/user v0.3.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/nats-io/nats.go v1.41.0 // indirect + github.com/nats-io/nats.go v1.41.0 github.com/nats-io/nkeys v0.4.10 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/oklog/ulid/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 994fd5e11..bd2476b71 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,8 @@ github.com/absmach/mgate v0.4.5 h1:l6RmrEsR9jxkdb9WHUSecmT0HA41TkZZQVffFfUAIfI= github.com/absmach/mgate v0.4.5/go.mod h1:IvRIHZexZPEIAPmmaJF0L5DY2ERjj+GxRGitOW4s6qo= github.com/absmach/senml v1.0.7 h1:XLvpw0qxbP2QhOz7KLM2ZRar+vSCpSG/0o0kEvWx3No= github.com/absmach/senml v1.0.7/go.mod h1:3bRIiNc8hq7l3auMs8gQrpsM5hHy7iDuiLILrf/+MfA= -github.com/absmach/supermq v0.16.1-0.20250411132829-0e571d1905af h1:FKVu3CX1V+8mob0hvfD0havIWXp9Wcv8nHJOeNhWmUE= -github.com/absmach/supermq v0.16.1-0.20250411132829-0e571d1905af/go.mod h1:dJqO3luvt+zLVuDhxOdsRRCuv945F86Nf1BXxFro+nU= +github.com/absmach/supermq v0.16.1-0.20250411155830-602025b48bc0 h1:vPTeIjQPCgDCMGSA9pICZNcTNpz8fRl5CBlpVewMCwY= +github.com/absmach/supermq v0.16.1-0.20250411155830-602025b48bc0/go.mod h1:dJqO3luvt+zLVuDhxOdsRRCuv945F86Nf1BXxFro+nU= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= diff --git a/pkg/sdk/bootstrap.go b/pkg/sdk/bootstrap.go index cc00a5920..3c7aa8317 100644 --- a/pkg/sdk/bootstrap.go +++ b/pkg/sdk/bootstrap.go @@ -4,6 +4,7 @@ package sdk import ( + "context" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -96,7 +97,7 @@ func (ts *BootstrapConfig) UnmarshalJSON(data []byte) error { return nil } -func (sdk mgSDK) AddBootstrap(cfg BootstrapConfig, domainID, token string) (string, errors.SDKError) { +func (sdk mgSDK) AddBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) (string, errors.SDKError) { data, err := json.Marshal(cfg) if err != nil { return "", errors.NewSDKError(err) @@ -104,7 +105,7 @@ func (sdk mgSDK) AddBootstrap(cfg BootstrapConfig, domainID, token string) (stri url := fmt.Sprintf("%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint) - headers, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusOK, http.StatusCreated) + headers, _, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, data, nil, http.StatusOK, http.StatusCreated) if sdkerr != nil { return "", sdkerr } @@ -114,14 +115,14 @@ func (sdk mgSDK) AddBootstrap(cfg BootstrapConfig, domainID, token string) (stri return id, nil } -func (sdk mgSDK) Bootstraps(pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) { +func (sdk mgSDK) Bootstraps(ctx context.Context, pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) { endpoint := fmt.Sprintf("%s/%s", domainID, configsEndpoint) url, err := sdk.withQueryParams(sdk.bootstrapURL, endpoint, pm) if err != nil { return BootstrapPage{}, errors.NewSDKError(err) } - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) if sdkerr != nil { return BootstrapPage{}, sdkerr } @@ -134,7 +135,7 @@ func (sdk mgSDK) Bootstraps(pm PageMetadata, domainID, token string) (BootstrapP return bb, nil } -func (sdk mgSDK) Whitelist(clientID string, state int, domainID, token string) errors.SDKError { +func (sdk mgSDK) Whitelist(ctx context.Context, clientID string, state int, domainID, token string) errors.SDKError { if clientID == "" { return errors.NewSDKError(apiutil.ErrMissingID) } @@ -146,18 +147,18 @@ func (sdk mgSDK) Whitelist(clientID string, state int, domainID, token string) e url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, whitelistEndpoint, clientID) - _, _, sdkerr := sdk.processRequest(http.MethodPut, url, token, data, nil, http.StatusCreated, http.StatusOK) + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusCreated, http.StatusOK) return sdkerr } -func (sdk mgSDK) ViewBootstrap(id, domainID, token string) (BootstrapConfig, errors.SDKError) { +func (sdk mgSDK) ViewBootstrap(ctx context.Context, id, domainID, token string) (BootstrapConfig, errors.SDKError) { if id == "" { return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) } url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint, id) - _, body, err := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) + _, body, err := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) if err != nil { return BootstrapConfig{}, err } @@ -170,7 +171,7 @@ func (sdk mgSDK) ViewBootstrap(id, domainID, token string) (BootstrapConfig, err return bc, nil } -func (sdk mgSDK) UpdateBootstrap(cfg BootstrapConfig, domainID, token string) errors.SDKError { +func (sdk mgSDK) UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) errors.SDKError { if cfg.ClientID == "" { return errors.NewSDKError(apiutil.ErrMissingID) } @@ -181,12 +182,12 @@ func (sdk mgSDK) UpdateBootstrap(cfg BootstrapConfig, domainID, token string) er return errors.NewSDKError(err) } - _, _, sdkerr := sdk.processRequest(http.MethodPut, url, token, data, nil, http.StatusOK) + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusOK) return sdkerr } -func (sdk mgSDK) UpdateBootstrapCerts(id, clientCert, clientKey, ca, domainID, token string) (BootstrapConfig, errors.SDKError) { +func (sdk mgSDK) UpdateBootstrapCerts(ctx context.Context, id, clientCert, clientKey, ca, domainID, token string) (BootstrapConfig, errors.SDKError) { if id == "" { return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) } @@ -202,7 +203,7 @@ func (sdk mgSDK) UpdateBootstrapCerts(id, clientCert, clientKey, ca, domainID, t return BootstrapConfig{}, errors.NewSDKError(err) } - _, body, sdkerr := sdk.processRequest(http.MethodPatch, url, token, data, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) if sdkerr != nil { return BootstrapConfig{}, sdkerr } @@ -215,7 +216,7 @@ func (sdk mgSDK) UpdateBootstrapCerts(id, clientCert, clientKey, ca, domainID, t return bc, nil } -func (sdk mgSDK) UpdateBootstrapConnection(id string, channels []string, domainID, token string) errors.SDKError { +func (sdk mgSDK) UpdateBootstrapConnection(ctx context.Context, id string, channels []string, domainID, token string) errors.SDKError { if id == "" { return errors.NewSDKError(apiutil.ErrMissingID) } @@ -228,27 +229,27 @@ func (sdk mgSDK) UpdateBootstrapConnection(id string, channels []string, domainI return errors.NewSDKError(err) } - _, _, sdkerr := sdk.processRequest(http.MethodPut, url, token, data, nil, http.StatusOK) + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusOK) return sdkerr } -func (sdk mgSDK) RemoveBootstrap(id, domainID, token string) errors.SDKError { +func (sdk mgSDK) RemoveBootstrap(ctx context.Context, id, domainID, token string) errors.SDKError { if id == "" { return errors.NewSDKError(apiutil.ErrMissingID) } url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint, id) - _, _, err := sdk.processRequest(http.MethodDelete, url, token, nil, nil, http.StatusNoContent) + _, _, err := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent) return err } -func (sdk mgSDK) Bootstrap(externalID, externalKey string) (BootstrapConfig, errors.SDKError) { +func (sdk mgSDK) Bootstrap(ctx context.Context, externalID, externalKey string) (BootstrapConfig, errors.SDKError) { if externalID == "" { return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) } url := fmt.Sprintf("%s/%s/%s", sdk.bootstrapURL, bootstrapEndpoint, externalID) - _, body, err := sdk.processRequest(http.MethodGet, url, smqSDK.ClientPrefix+externalKey, nil, nil, http.StatusOK) + _, body, err := sdk.processRequest(ctx, http.MethodGet, url, smqSDK.ClientPrefix+externalKey, nil, nil, http.StatusOK) if err != nil { return BootstrapConfig{}, err } @@ -261,7 +262,7 @@ func (sdk mgSDK) Bootstrap(externalID, externalKey string) (BootstrapConfig, err return bc, nil } -func (sdk mgSDK) BootstrapSecure(externalID, externalKey, cryptoKey string) (BootstrapConfig, errors.SDKError) { +func (sdk mgSDK) BootstrapSecure(ctx context.Context, externalID, externalKey, cryptoKey string) (BootstrapConfig, errors.SDKError) { if externalID == "" { return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) } @@ -272,7 +273,7 @@ func (sdk mgSDK) BootstrapSecure(externalID, externalKey, cryptoKey string) (Boo return BootstrapConfig{}, errors.NewSDKError(err) } - _, body, sdkErr := sdk.processRequest(http.MethodGet, url, smqSDK.ClientPrefix+encExtKey, nil, nil, http.StatusOK) + _, body, sdkErr := sdk.processRequest(ctx, http.MethodGet, url, smqSDK.ClientPrefix+encExtKey, nil, nil, http.StatusOK) if sdkErr != nil { return BootstrapConfig{}, sdkErr } diff --git a/pkg/sdk/bootstrap_test.go b/pkg/sdk/bootstrap_test.go index aafbc03ab..16a13094c 100644 --- a/pkg/sdk/bootstrap_test.go +++ b/pkg/sdk/bootstrap_test.go @@ -4,6 +4,7 @@ package sdk_test import ( + "context" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -252,7 +253,7 @@ func TestAddBootstrap(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("Add", mock.Anything, tc.session, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr) - resp, err := mgsdk.AddBootstrap(tc.cfg, tc.domainID, tc.token) + resp, err := mgsdk.AddBootstrap(context.Background(), tc.cfg, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if err == nil { assert.Equal(t, bootstrapConfig.ClientID, resp) @@ -399,7 +400,7 @@ func TestListBootstraps(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("List", mock.Anything, tc.session, mock.Anything, tc.pageMeta.Offset, tc.pageMeta.Limit).Return(tc.svcResp, tc.svcErr) - resp, err := mgsdk.Bootstraps(tc.pageMeta, tc.domainID, tc.token) + resp, err := mgsdk.Bootstraps(context.Background(), tc.pageMeta, tc.domainID, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if err == nil { @@ -504,7 +505,7 @@ func TestWhiteList(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("ChangeState", mock.Anything, tc.session, tc.token, tc.clientID, tc.svcReq).Return(tc.svcErr) - err := mgsdk.Whitelist(tc.clientID, tc.state, tc.domainID, tc.token) + err := mgsdk.Whitelist(context.Background(), tc.clientID, tc.state, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "ChangeState", mock.Anything, tc.session, tc.token, tc.clientID, tc.svcReq) @@ -625,7 +626,7 @@ func TestViewBootstrap(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("View", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) - resp, err := mgsdk.ViewBootstrap(tc.id, tc.domainID, tc.token) + resp, err := mgsdk.ViewBootstrap(context.Background(), tc.id, tc.domainID, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if err == nil { @@ -788,7 +789,7 @@ func TestUpdateBootstrap(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticationErr) svcCall := bsvc.On("Update", mock.Anything, tc.session, tc.svcReq).Return(tc.svcErr) - err := mgsdk.UpdateBootstrap(tc.cfg, tc.domainID, tc.token) + err := mgsdk.UpdateBootstrap(context.Background(), tc.cfg, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "Update", mock.Anything, tc.session, tc.svcReq) @@ -912,7 +913,7 @@ func TestUpdateBootstrapCerts(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("UpdateCert", mock.Anything, tc.session, tc.id, tc.clientCert, tc.clientKey, tc.caCert).Return(tc.svcResp, tc.svcErr) - resp, err := mgsdk.UpdateBootstrapCerts(tc.id, tc.clientCert, tc.clientKey, tc.caCert, tc.domainID, tc.token) + resp, err := mgsdk.UpdateBootstrapCerts(context.Background(), tc.id, tc.clientCert, tc.clientKey, tc.caCert, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if err == nil { assert.Equal(t, tc.response, resp) @@ -1015,7 +1016,7 @@ func TestUpdateBootstrapConnection(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("UpdateConnections", mock.Anything, tc.session, tc.token, tc.id, tc.channels).Return(tc.svcErr) - err := mgsdk.UpdateBootstrapConnection(tc.id, tc.channels, tc.domainID, tc.token) + err := mgsdk.UpdateBootstrapConnection(context.Background(), tc.id, tc.channels, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "UpdateConnections", mock.Anything, tc.session, tc.token, tc.id, tc.channels) @@ -1102,7 +1103,7 @@ func TestRemoveBootstrap(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("Remove", mock.Anything, tc.session, tc.id).Return(tc.svcErr) - err := mgsdk.RemoveBootstrap(tc.id, tc.domainID, tc.token) + err := mgsdk.RemoveBootstrap(context.Background(), tc.id, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "Remove", mock.Anything, tc.session, tc.id) @@ -1207,7 +1208,7 @@ func TestBoostrap(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { svcCall := bsvc.On("Bootstrap", mock.Anything, tc.externalKey, tc.externalID, false).Return(tc.svcResp, tc.svcErr) readerCall := reader.On("ReadConfig", tc.svcResp, false).Return(tc.readerResp, tc.readerErr) - resp, err := mgsdk.Bootstrap(tc.externalID, tc.externalKey) + resp, err := mgsdk.Bootstrap(context.Background(), tc.externalID, tc.externalKey) assert.Equal(t, tc.err, err) if err == nil { assert.Equal(t, tc.response, resp) @@ -1324,7 +1325,7 @@ func TestBootstrapSecure(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { svcCall := bsvc.On("Bootstrap", mock.Anything, mock.Anything, tc.externalID, true).Return(tc.svcResp, tc.svcErr) readerCall := reader.On("ReadConfig", tc.svcResp, true).Return(tc.readerResp, tc.readerErr) - resp, err := mgsdk.BootstrapSecure(tc.externalID, tc.externalKey, tc.cryptoKey) + resp, err := mgsdk.BootstrapSecure(context.Background(), tc.externalID, tc.externalKey, tc.cryptoKey) assert.Equal(t, tc.err, err) if err == nil { assert.Equal(t, sdkBootsrapConfigRes, resp) diff --git a/pkg/sdk/consumers.go b/pkg/sdk/consumers.go index 3019db67e..60c2dcf0d 100644 --- a/pkg/sdk/consumers.go +++ b/pkg/sdk/consumers.go @@ -4,6 +4,7 @@ package sdk import ( + "context" "encoding/json" "fmt" "net/http" @@ -21,7 +22,7 @@ type Subscription struct { Contact string `json:"contact,omitempty"` } -func (sdk mgSDK) CreateSubscription(topic, contact, token string) (string, errors.SDKError) { +func (sdk mgSDK) CreateSubscription(ctx context.Context, topic, contact, token string) (string, errors.SDKError) { sub := Subscription{ Topic: topic, Contact: contact, @@ -33,7 +34,7 @@ func (sdk mgSDK) CreateSubscription(topic, contact, token string) (string, error url := fmt.Sprintf("%s/%s", sdk.usersURL, subscriptionEndpoint) - headers, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusCreated) + headers, _, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, data, nil, http.StatusCreated) if sdkerr != nil { return "", sdkerr } @@ -43,13 +44,13 @@ func (sdk mgSDK) CreateSubscription(topic, contact, token string) (string, error return id, nil } -func (sdk mgSDK) ListSubscriptions(pm PageMetadata, token string) (SubscriptionPage, errors.SDKError) { +func (sdk mgSDK) ListSubscriptions(ctx context.Context, pm PageMetadata, token string) (SubscriptionPage, errors.SDKError) { url, err := sdk.withQueryParams(sdk.usersURL, subscriptionEndpoint, pm) if err != nil { return SubscriptionPage{}, errors.NewSDKError(err) } - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) if sdkerr != nil { return SubscriptionPage{}, sdkerr } @@ -62,10 +63,10 @@ func (sdk mgSDK) ListSubscriptions(pm PageMetadata, token string) (SubscriptionP return sp, nil } -func (sdk mgSDK) ViewSubscription(id, token string) (Subscription, errors.SDKError) { +func (sdk mgSDK) ViewSubscription(ctx context.Context, id, token string) (Subscription, errors.SDKError) { url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, subscriptionEndpoint, id) - _, body, err := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) + _, body, err := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) if err != nil { return Subscription{}, err } @@ -78,10 +79,10 @@ func (sdk mgSDK) ViewSubscription(id, token string) (Subscription, errors.SDKErr return sub, nil } -func (sdk mgSDK) DeleteSubscription(id, token string) errors.SDKError { +func (sdk mgSDK) DeleteSubscription(ctx context.Context, id, token string) errors.SDKError { url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, subscriptionEndpoint, id) - _, _, err := sdk.processRequest(http.MethodDelete, url, token, nil, nil, http.StatusNoContent) + _, _, err := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent) return err } diff --git a/pkg/sdk/consumers_test.go b/pkg/sdk/consumers_test.go index 6fc997719..bd07c4af4 100644 --- a/pkg/sdk/consumers_test.go +++ b/pkg/sdk/consumers_test.go @@ -4,6 +4,7 @@ package sdk_test import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -142,7 +143,7 @@ func TestCreateSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svcCall := nsvc.On("CreateSubscription", mock.Anything, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr) - loc, err := mgsdk.CreateSubscription(tc.subscription.Topic, tc.subscription.Contact, tc.token) + loc, err := mgsdk.CreateSubscription(context.Background(), tc.subscription.Topic, tc.subscription.Contact, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.empty, loc == "") if tc.err == nil { @@ -214,7 +215,7 @@ func TestViewSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svcCall := nsvc.On("ViewSubscription", mock.Anything, tc.token, tc.subID).Return(tc.svcRes, tc.svcErr) - resp, err := mgsdk.ViewSubscription(tc.subID, tc.token) + resp, err := mgsdk.ViewSubscription(context.Background(), tc.subID, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if tc.err == nil { @@ -374,7 +375,7 @@ func TestListSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svcCall := nsvc.On("ListSubscriptions", mock.Anything, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr) - resp, err := mgsdk.ListSubscriptions(tc.pageMeta, tc.token) + resp, err := mgsdk.ListSubscriptions(context.Background(), tc.pageMeta, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if tc.err == nil { @@ -444,7 +445,7 @@ func TestDeleteSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svcCall := nsvc.On("RemoveSubscription", mock.Anything, tc.token, tc.subID).Return(tc.svcErr) - err := mgsdk.DeleteSubscription(tc.subID, tc.token) + err := mgsdk.DeleteSubscription(context.Background(), tc.subID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "RemoveSubscription", mock.Anything, tc.token, tc.subID) diff --git a/pkg/sdk/messages.go b/pkg/sdk/messages.go index 3dec4cf50..b6d0e9931 100644 --- a/pkg/sdk/messages.go +++ b/pkg/sdk/messages.go @@ -4,6 +4,7 @@ package sdk import ( + "context" "encoding/json" "fmt" "net/http" @@ -16,7 +17,7 @@ import ( const channelParts = 2 -func (sdk mgSDK) ReadMessages(pm MessagePageMetadata, chanName, domainID, token string) (MessagesPage, errors.SDKError) { +func (sdk mgSDK) ReadMessages(ctx context.Context, pm MessagePageMetadata, chanName, domainID, token string) (MessagesPage, errors.SDKError) { chanNameParts := strings.SplitN(chanName, ".", channelParts) chanID := chanNameParts[0] subtopicPart := "" @@ -32,7 +33,7 @@ func (sdk mgSDK) ReadMessages(pm MessagePageMetadata, chanName, domainID, token header := make(map[string]string) header["Content-Type"] = string(sdk.msgContentType) - _, body, sdkerr := sdk.processRequest(http.MethodGet, msgURL, token, nil, header, http.StatusOK) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, msgURL, token, nil, header, http.StatusOK) if sdkerr != nil { return MessagesPage{}, sdkerr } diff --git a/pkg/sdk/messages_test.go b/pkg/sdk/messages_test.go index c8fce8af4..f52d01ccd 100644 --- a/pkg/sdk/messages_test.go +++ b/pkg/sdk/messages_test.go @@ -4,6 +4,7 @@ package sdk_test import ( + "context" "net/http" "net/http/httptest" "testing" @@ -229,7 +230,7 @@ func TestReadMessages(t *testing.T) { authCall1 := authn.On("Authenticate", mock.Anything, tc.token).Return(smqauthn.Session{UserID: validID}, tc.authnErr) authzCall := channelsGRPCClient.On("Authorize", mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, tc.authzErr) repoCall := repo.On("ReadAll", channelID, mock.Anything).Return(tc.repoRes, tc.repoErr) - response, err := mgsdk.ReadMessages(tc.messagePageMeta, tc.chanName, tc.domainID, tc.token) + response, err := mgsdk.ReadMessages(context.Background(), tc.messagePageMeta, tc.chanName, tc.domainID, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, response) if tc.err == nil { diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index d67ff4203..91eb56458 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -92,8 +92,8 @@ func (_c *SDK_AcceptInvitation_Call) RunAndReturn(run func(ctx context.Context, } // AddBootstrap provides a mock function for the type SDK -func (_mock *SDK) AddBootstrap(cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError) { - ret := _mock.Called(cfg, domainID, token) +func (_mock *SDK) AddBootstrap(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError) { + ret := _mock.Called(ctx, cfg, domainID, token) if len(ret) == 0 { panic("no return value specified for AddBootstrap") @@ -101,16 +101,16 @@ func (_mock *SDK) AddBootstrap(cfg sdk.BootstrapConfig, domainID string, token s var r0 string var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.BootstrapConfig, string, string) (string, errors.SDKError)); ok { - return returnFunc(cfg, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapConfig, string, string) (string, errors.SDKError)); ok { + return returnFunc(ctx, cfg, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(sdk.BootstrapConfig, string, string) string); ok { - r0 = returnFunc(cfg, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapConfig, string, string) string); ok { + r0 = returnFunc(ctx, cfg, domainID, token) } else { r0 = ret.Get(0).(string) } - if returnFunc, ok := ret.Get(1).(func(sdk.BootstrapConfig, string, string) errors.SDKError); ok { - r1 = returnFunc(cfg, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.BootstrapConfig, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, cfg, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -125,16 +125,17 @@ type SDK_AddBootstrap_Call struct { } // AddBootstrap is a helper method to define mock.On call +// - ctx // - cfg // - domainID // - token -func (_e *SDK_Expecter) AddBootstrap(cfg interface{}, domainID interface{}, token interface{}) *SDK_AddBootstrap_Call { - return &SDK_AddBootstrap_Call{Call: _e.mock.On("AddBootstrap", cfg, domainID, token)} +func (_e *SDK_Expecter) AddBootstrap(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_AddBootstrap_Call { + return &SDK_AddBootstrap_Call{Call: _e.mock.On("AddBootstrap", ctx, cfg, domainID, token)} } -func (_c *SDK_AddBootstrap_Call) Run(run func(cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_AddBootstrap_Call { +func (_c *SDK_AddBootstrap_Call) Run(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_AddBootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.BootstrapConfig), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(sdk.BootstrapConfig), args[2].(string), args[3].(string)) }) return _c } @@ -144,7 +145,7 @@ func (_c *SDK_AddBootstrap_Call) Return(s string, sDKError errors.SDKError) *SDK return _c } -func (_c *SDK_AddBootstrap_Call) RunAndReturn(run func(cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError)) *SDK_AddBootstrap_Call { +func (_c *SDK_AddBootstrap_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError)) *SDK_AddBootstrap_Call { _c.Call.Return(run) return _c } @@ -756,8 +757,8 @@ func (_c *SDK_AvailableGroupRoleActions_Call) RunAndReturn(run func(ctx context. } // Bootstrap provides a mock function for the type SDK -func (_mock *SDK) Bootstrap(externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError) { - ret := _mock.Called(externalID, externalKey) +func (_mock *SDK) Bootstrap(ctx context.Context, externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, externalID, externalKey) if len(ret) == 0 { panic("no return value specified for Bootstrap") @@ -765,16 +766,16 @@ func (_mock *SDK) Bootstrap(externalID string, externalKey string) (sdk.Bootstra var r0 sdk.BootstrapConfig var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { - return returnFunc(externalID, externalKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, externalID, externalKey) } - if returnFunc, ok := ret.Get(0).(func(string, string) sdk.BootstrapConfig); ok { - r0 = returnFunc(externalID, externalKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, externalID, externalKey) } else { r0 = ret.Get(0).(sdk.BootstrapConfig) } - if returnFunc, ok := ret.Get(1).(func(string, string) errors.SDKError); ok { - r1 = returnFunc(externalID, externalKey) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, externalID, externalKey) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -789,15 +790,16 @@ type SDK_Bootstrap_Call struct { } // Bootstrap is a helper method to define mock.On call +// - ctx // - externalID // - externalKey -func (_e *SDK_Expecter) Bootstrap(externalID interface{}, externalKey interface{}) *SDK_Bootstrap_Call { - return &SDK_Bootstrap_Call{Call: _e.mock.On("Bootstrap", externalID, externalKey)} +func (_e *SDK_Expecter) Bootstrap(ctx interface{}, externalID interface{}, externalKey interface{}) *SDK_Bootstrap_Call { + return &SDK_Bootstrap_Call{Call: _e.mock.On("Bootstrap", ctx, externalID, externalKey)} } -func (_c *SDK_Bootstrap_Call) Run(run func(externalID string, externalKey string)) *SDK_Bootstrap_Call { +func (_c *SDK_Bootstrap_Call) Run(run func(ctx context.Context, externalID string, externalKey string)) *SDK_Bootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string)) }) return _c } @@ -807,14 +809,14 @@ func (_c *SDK_Bootstrap_Call) Return(bootstrapConfig sdk.BootstrapConfig, sDKErr return _c } -func (_c *SDK_Bootstrap_Call) RunAndReturn(run func(externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_Bootstrap_Call { +func (_c *SDK_Bootstrap_Call) RunAndReturn(run func(ctx context.Context, externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_Bootstrap_Call { _c.Call.Return(run) return _c } // BootstrapSecure provides a mock function for the type SDK -func (_mock *SDK) BootstrapSecure(externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError) { - ret := _mock.Called(externalID, externalKey, cryptoKey) +func (_mock *SDK) BootstrapSecure(ctx context.Context, externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, externalID, externalKey, cryptoKey) if len(ret) == 0 { panic("no return value specified for BootstrapSecure") @@ -822,16 +824,16 @@ func (_mock *SDK) BootstrapSecure(externalID string, externalKey string, cryptoK var r0 sdk.BootstrapConfig var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { - return returnFunc(externalID, externalKey, cryptoKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, externalID, externalKey, cryptoKey) } - if returnFunc, ok := ret.Get(0).(func(string, string, string) sdk.BootstrapConfig); ok { - r0 = returnFunc(externalID, externalKey, cryptoKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, externalID, externalKey, cryptoKey) } else { r0 = ret.Get(0).(sdk.BootstrapConfig) } - if returnFunc, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { - r1 = returnFunc(externalID, externalKey, cryptoKey) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, externalID, externalKey, cryptoKey) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -846,16 +848,17 @@ type SDK_BootstrapSecure_Call struct { } // BootstrapSecure is a helper method to define mock.On call +// - ctx // - externalID // - externalKey // - cryptoKey -func (_e *SDK_Expecter) BootstrapSecure(externalID interface{}, externalKey interface{}, cryptoKey interface{}) *SDK_BootstrapSecure_Call { - return &SDK_BootstrapSecure_Call{Call: _e.mock.On("BootstrapSecure", externalID, externalKey, cryptoKey)} +func (_e *SDK_Expecter) BootstrapSecure(ctx interface{}, externalID interface{}, externalKey interface{}, cryptoKey interface{}) *SDK_BootstrapSecure_Call { + return &SDK_BootstrapSecure_Call{Call: _e.mock.On("BootstrapSecure", ctx, externalID, externalKey, cryptoKey)} } -func (_c *SDK_BootstrapSecure_Call) Run(run func(externalID string, externalKey string, cryptoKey string)) *SDK_BootstrapSecure_Call { +func (_c *SDK_BootstrapSecure_Call) Run(run func(ctx context.Context, externalID string, externalKey string, cryptoKey string)) *SDK_BootstrapSecure_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -865,14 +868,14 @@ func (_c *SDK_BootstrapSecure_Call) Return(bootstrapConfig sdk.BootstrapConfig, return _c } -func (_c *SDK_BootstrapSecure_Call) RunAndReturn(run func(externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_BootstrapSecure_Call { +func (_c *SDK_BootstrapSecure_Call) RunAndReturn(run func(ctx context.Context, externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_BootstrapSecure_Call { _c.Call.Return(run) return _c } // Bootstraps provides a mock function for the type SDK -func (_mock *SDK) Bootstraps(pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError) { - ret := _mock.Called(pm, domainID, token) +func (_mock *SDK) Bootstraps(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, domainID, token) if len(ret) == 0 { panic("no return value specified for Bootstraps") @@ -880,16 +883,16 @@ func (_mock *SDK) Bootstraps(pm sdk.PageMetadata, domainID string, token string) var r0 sdk.BootstrapPage var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) (sdk.BootstrapPage, errors.SDKError)); ok { - return returnFunc(pm, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) (sdk.BootstrapPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) sdk.BootstrapPage); ok { - r0 = returnFunc(pm, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) sdk.BootstrapPage); ok { + r0 = returnFunc(ctx, pm, domainID, token) } else { r0 = ret.Get(0).(sdk.BootstrapPage) } - if returnFunc, ok := ret.Get(1).(func(sdk.PageMetadata, string, string) errors.SDKError); ok { - r1 = returnFunc(pm, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -904,16 +907,17 @@ type SDK_Bootstraps_Call struct { } // Bootstraps is a helper method to define mock.On call +// - ctx // - pm // - domainID // - token -func (_e *SDK_Expecter) Bootstraps(pm interface{}, domainID interface{}, token interface{}) *SDK_Bootstraps_Call { - return &SDK_Bootstraps_Call{Call: _e.mock.On("Bootstraps", pm, domainID, token)} +func (_e *SDK_Expecter) Bootstraps(ctx interface{}, pm interface{}, domainID interface{}, token interface{}) *SDK_Bootstraps_Call { + return &SDK_Bootstraps_Call{Call: _e.mock.On("Bootstraps", ctx, pm, domainID, token)} } -func (_c *SDK_Bootstraps_Call) Run(run func(pm sdk.PageMetadata, domainID string, token string)) *SDK_Bootstraps_Call { +func (_c *SDK_Bootstraps_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string)) *SDK_Bootstraps_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.PageMetadata), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(sdk.PageMetadata), args[2].(string), args[3].(string)) }) return _c } @@ -923,7 +927,7 @@ func (_c *SDK_Bootstraps_Call) Return(bootstrapPage sdk.BootstrapPage, sDKError return _c } -func (_c *SDK_Bootstraps_Call) RunAndReturn(run func(pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError)) *SDK_Bootstraps_Call { +func (_c *SDK_Bootstraps_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError)) *SDK_Bootstraps_Call { _c.Call.Return(run) return _c } @@ -2106,8 +2110,8 @@ func (_c *SDK_CreateGroupRole_Call) RunAndReturn(run func(ctx context.Context, i } // CreateSubscription provides a mock function for the type SDK -func (_mock *SDK) CreateSubscription(topic string, contact string, token string) (string, errors.SDKError) { - ret := _mock.Called(topic, contact, token) +func (_mock *SDK) CreateSubscription(ctx context.Context, topic string, contact string, token string) (string, errors.SDKError) { + ret := _mock.Called(ctx, topic, contact, token) if len(ret) == 0 { panic("no return value specified for CreateSubscription") @@ -2115,16 +2119,16 @@ func (_mock *SDK) CreateSubscription(topic string, contact string, token string) var r0 string var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string) (string, errors.SDKError)); ok { - return returnFunc(topic, contact, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (string, errors.SDKError)); ok { + return returnFunc(ctx, topic, contact, token) } - if returnFunc, ok := ret.Get(0).(func(string, string, string) string); ok { - r0 = returnFunc(topic, contact, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) string); ok { + r0 = returnFunc(ctx, topic, contact, token) } else { r0 = ret.Get(0).(string) } - if returnFunc, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { - r1 = returnFunc(topic, contact, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, topic, contact, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -2139,16 +2143,17 @@ type SDK_CreateSubscription_Call struct { } // CreateSubscription is a helper method to define mock.On call +// - ctx // - topic // - contact // - token -func (_e *SDK_Expecter) CreateSubscription(topic interface{}, contact interface{}, token interface{}) *SDK_CreateSubscription_Call { - return &SDK_CreateSubscription_Call{Call: _e.mock.On("CreateSubscription", topic, contact, token)} +func (_e *SDK_Expecter) CreateSubscription(ctx interface{}, topic interface{}, contact interface{}, token interface{}) *SDK_CreateSubscription_Call { + return &SDK_CreateSubscription_Call{Call: _e.mock.On("CreateSubscription", ctx, topic, contact, token)} } -func (_c *SDK_CreateSubscription_Call) Run(run func(topic string, contact string, token string)) *SDK_CreateSubscription_Call { +func (_c *SDK_CreateSubscription_Call) Run(run func(ctx context.Context, topic string, contact string, token string)) *SDK_CreateSubscription_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -2158,7 +2163,7 @@ func (_c *SDK_CreateSubscription_Call) Return(s string, sDKError errors.SDKError return _c } -func (_c *SDK_CreateSubscription_Call) RunAndReturn(run func(topic string, contact string, token string) (string, errors.SDKError)) *SDK_CreateSubscription_Call { +func (_c *SDK_CreateSubscription_Call) RunAndReturn(run func(ctx context.Context, topic string, contact string, token string) (string, errors.SDKError)) *SDK_CreateSubscription_Call { _c.Call.Return(run) return _c } @@ -2629,16 +2634,16 @@ func (_c *SDK_DeleteInvitation_Call) RunAndReturn(run func(ctx context.Context, } // DeleteSubscription provides a mock function for the type SDK -func (_mock *SDK) DeleteSubscription(id string, token string) errors.SDKError { - ret := _mock.Called(id, token) +func (_mock *SDK) DeleteSubscription(ctx context.Context, id string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, token) if len(ret) == 0 { panic("no return value specified for DeleteSubscription") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string) errors.SDKError); ok { - r0 = returnFunc(id, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -2653,15 +2658,16 @@ type SDK_DeleteSubscription_Call struct { } // DeleteSubscription is a helper method to define mock.On call +// - ctx // - id // - token -func (_e *SDK_Expecter) DeleteSubscription(id interface{}, token interface{}) *SDK_DeleteSubscription_Call { - return &SDK_DeleteSubscription_Call{Call: _e.mock.On("DeleteSubscription", id, token)} +func (_e *SDK_Expecter) DeleteSubscription(ctx interface{}, id interface{}, token interface{}) *SDK_DeleteSubscription_Call { + return &SDK_DeleteSubscription_Call{Call: _e.mock.On("DeleteSubscription", ctx, id, token)} } -func (_c *SDK_DeleteSubscription_Call) Run(run func(id string, token string)) *SDK_DeleteSubscription_Call { +func (_c *SDK_DeleteSubscription_Call) Run(run func(ctx context.Context, id string, token string)) *SDK_DeleteSubscription_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string)) }) return _c } @@ -2671,7 +2677,7 @@ func (_c *SDK_DeleteSubscription_Call) Return(sDKError errors.SDKError) *SDK_Del return _c } -func (_c *SDK_DeleteSubscription_Call) RunAndReturn(run func(id string, token string) errors.SDKError) *SDK_DeleteSubscription_Call { +func (_c *SDK_DeleteSubscription_Call) RunAndReturn(run func(ctx context.Context, id string, token string) errors.SDKError) *SDK_DeleteSubscription_Call { _c.Call.Return(run) return _c } @@ -4748,8 +4754,8 @@ func (_c *SDK_ListGroupMembers_Call) RunAndReturn(run func(ctx context.Context, } // ListSubscriptions provides a mock function for the type SDK -func (_mock *SDK) ListSubscriptions(pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError) { - ret := _mock.Called(pm, token) +func (_mock *SDK) ListSubscriptions(ctx context.Context, pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, token) if len(ret) == 0 { panic("no return value specified for ListSubscriptions") @@ -4757,16 +4763,16 @@ func (_mock *SDK) ListSubscriptions(pm sdk.PageMetadata, token string) (sdk.Subs var r0 sdk.SubscriptionPage var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.SubscriptionPage, errors.SDKError)); ok { - return returnFunc(pm, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string) (sdk.SubscriptionPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, token) } - if returnFunc, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.SubscriptionPage); ok { - r0 = returnFunc(pm, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string) sdk.SubscriptionPage); ok { + r0 = returnFunc(ctx, pm, token) } else { r0 = ret.Get(0).(sdk.SubscriptionPage) } - if returnFunc, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok { - r1 = returnFunc(pm, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -4781,15 +4787,16 @@ type SDK_ListSubscriptions_Call struct { } // ListSubscriptions is a helper method to define mock.On call +// - ctx // - pm // - token -func (_e *SDK_Expecter) ListSubscriptions(pm interface{}, token interface{}) *SDK_ListSubscriptions_Call { - return &SDK_ListSubscriptions_Call{Call: _e.mock.On("ListSubscriptions", pm, token)} +func (_e *SDK_Expecter) ListSubscriptions(ctx interface{}, pm interface{}, token interface{}) *SDK_ListSubscriptions_Call { + return &SDK_ListSubscriptions_Call{Call: _e.mock.On("ListSubscriptions", ctx, pm, token)} } -func (_c *SDK_ListSubscriptions_Call) Run(run func(pm sdk.PageMetadata, token string)) *SDK_ListSubscriptions_Call { +func (_c *SDK_ListSubscriptions_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, token string)) *SDK_ListSubscriptions_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.PageMetadata), args[1].(string)) + run(args[0].(context.Context), args[1].(sdk.PageMetadata), args[2].(string)) }) return _c } @@ -4799,14 +4806,14 @@ func (_c *SDK_ListSubscriptions_Call) Return(subscriptionPage sdk.SubscriptionPa return _c } -func (_c *SDK_ListSubscriptions_Call) RunAndReturn(run func(pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError)) *SDK_ListSubscriptions_Call { +func (_c *SDK_ListSubscriptions_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError)) *SDK_ListSubscriptions_Call { _c.Call.Return(run) return _c } // ReadMessages provides a mock function for the type SDK -func (_mock *SDK) ReadMessages(pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError) { - ret := _mock.Called(pm, chanID, domainID, token) +func (_mock *SDK) ReadMessages(ctx context.Context, pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, chanID, domainID, token) if len(ret) == 0 { panic("no return value specified for ReadMessages") @@ -4814,16 +4821,16 @@ func (_mock *SDK) ReadMessages(pm sdk.MessagePageMetadata, chanID string, domain var r0 sdk.MessagesPage var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.MessagePageMetadata, string, string, string) (sdk.MessagesPage, errors.SDKError)); ok { - return returnFunc(pm, chanID, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.MessagePageMetadata, string, string, string) (sdk.MessagesPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, chanID, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(sdk.MessagePageMetadata, string, string, string) sdk.MessagesPage); ok { - r0 = returnFunc(pm, chanID, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.MessagePageMetadata, string, string, string) sdk.MessagesPage); ok { + r0 = returnFunc(ctx, pm, chanID, domainID, token) } else { r0 = ret.Get(0).(sdk.MessagesPage) } - if returnFunc, ok := ret.Get(1).(func(sdk.MessagePageMetadata, string, string, string) errors.SDKError); ok { - r1 = returnFunc(pm, chanID, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.MessagePageMetadata, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, chanID, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -4838,17 +4845,18 @@ type SDK_ReadMessages_Call struct { } // ReadMessages is a helper method to define mock.On call +// - ctx // - pm // - chanID // - domainID // - token -func (_e *SDK_Expecter) ReadMessages(pm interface{}, chanID interface{}, domainID interface{}, token interface{}) *SDK_ReadMessages_Call { - return &SDK_ReadMessages_Call{Call: _e.mock.On("ReadMessages", pm, chanID, domainID, token)} +func (_e *SDK_Expecter) ReadMessages(ctx interface{}, pm interface{}, chanID interface{}, domainID interface{}, token interface{}) *SDK_ReadMessages_Call { + return &SDK_ReadMessages_Call{Call: _e.mock.On("ReadMessages", ctx, pm, chanID, domainID, token)} } -func (_c *SDK_ReadMessages_Call) Run(run func(pm sdk.MessagePageMetadata, chanID string, domainID string, token string)) *SDK_ReadMessages_Call { +func (_c *SDK_ReadMessages_Call) Run(run func(ctx context.Context, pm sdk.MessagePageMetadata, chanID string, domainID string, token string)) *SDK_ReadMessages_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.MessagePageMetadata), args[1].(string), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(sdk.MessagePageMetadata), args[2].(string), args[3].(string), args[4].(string)) }) return _c } @@ -4858,7 +4866,7 @@ func (_c *SDK_ReadMessages_Call) Return(messagesPage sdk.MessagesPage, sDKError return _c } -func (_c *SDK_ReadMessages_Call) RunAndReturn(run func(pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError)) *SDK_ReadMessages_Call { +func (_c *SDK_ReadMessages_Call) RunAndReturn(run func(ctx context.Context, pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError)) *SDK_ReadMessages_Call { _c.Call.Return(run) return _c } @@ -5322,16 +5330,16 @@ func (_c *SDK_RemoveAllGroupRoleMembers_Call) RunAndReturn(run func(ctx context. } // RemoveBootstrap provides a mock function for the type SDK -func (_mock *SDK) RemoveBootstrap(id string, domainID string, token string) errors.SDKError { - ret := _mock.Called(id, domainID, token) +func (_mock *SDK) RemoveBootstrap(ctx context.Context, id string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, domainID, token) if len(ret) == 0 { panic("no return value specified for RemoveBootstrap") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string) errors.SDKError); ok { - r0 = returnFunc(id, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, domainID, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -5346,16 +5354,17 @@ type SDK_RemoveBootstrap_Call struct { } // RemoveBootstrap is a helper method to define mock.On call +// - ctx // - id // - domainID // - token -func (_e *SDK_Expecter) RemoveBootstrap(id interface{}, domainID interface{}, token interface{}) *SDK_RemoveBootstrap_Call { - return &SDK_RemoveBootstrap_Call{Call: _e.mock.On("RemoveBootstrap", id, domainID, token)} +func (_e *SDK_Expecter) RemoveBootstrap(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_RemoveBootstrap_Call { + return &SDK_RemoveBootstrap_Call{Call: _e.mock.On("RemoveBootstrap", ctx, id, domainID, token)} } -func (_c *SDK_RemoveBootstrap_Call) Run(run func(id string, domainID string, token string)) *SDK_RemoveBootstrap_Call { +func (_c *SDK_RemoveBootstrap_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_RemoveBootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -5365,7 +5374,7 @@ func (_c *SDK_RemoveBootstrap_Call) Return(sDKError errors.SDKError) *SDK_Remove return _c } -func (_c *SDK_RemoveBootstrap_Call) RunAndReturn(run func(id string, domainID string, token string) errors.SDKError) *SDK_RemoveBootstrap_Call { +func (_c *SDK_RemoveBootstrap_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) errors.SDKError) *SDK_RemoveBootstrap_Call { _c.Call.Return(run) return _c } @@ -6147,8 +6156,8 @@ func (_c *SDK_SendInvitation_Call) RunAndReturn(run func(ctx context.Context, in } // SendMessage provides a mock function for the type SDK -func (_mock *SDK) SendMessage(ctx context.Context, domainID string, topic string, secret string, msg string) errors.SDKError { - ret := _mock.Called(ctx, domainID, topic, secret, msg) +func (_mock *SDK) SendMessage(ctx context.Context, domainID string, topic string, msg string, secret string) errors.SDKError { + ret := _mock.Called(ctx, domainID, topic, msg, secret) if len(ret) == 0 { panic("no return value specified for SendMessage") @@ -6156,7 +6165,7 @@ func (_mock *SDK) SendMessage(ctx context.Context, domainID string, topic string var r0 errors.SDKError if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) errors.SDKError); ok { - r0 = returnFunc(ctx, domainID, topic, secret, msg) + r0 = returnFunc(ctx, domainID, topic, msg, secret) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -6174,13 +6183,13 @@ type SDK_SendMessage_Call struct { // - ctx // - domainID // - topic -// - secret // - msg -func (_e *SDK_Expecter) SendMessage(ctx interface{}, domainID interface{}, topic interface{}, secret interface{}, msg interface{}) *SDK_SendMessage_Call { - return &SDK_SendMessage_Call{Call: _e.mock.On("SendMessage", ctx, domainID, topic, secret, msg)} +// - secret +func (_e *SDK_Expecter) SendMessage(ctx interface{}, domainID interface{}, topic interface{}, msg interface{}, secret interface{}) *SDK_SendMessage_Call { + return &SDK_SendMessage_Call{Call: _e.mock.On("SendMessage", ctx, domainID, topic, msg, secret)} } -func (_c *SDK_SendMessage_Call) Run(run func(ctx context.Context, domainID string, topic string, secret string, msg string)) *SDK_SendMessage_Call { +func (_c *SDK_SendMessage_Call) Run(run func(ctx context.Context, domainID string, topic string, msg string, secret string)) *SDK_SendMessage_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) }) @@ -6192,7 +6201,7 @@ func (_c *SDK_SendMessage_Call) Return(sDKError errors.SDKError) *SDK_SendMessag return _c } -func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, domainID string, topic string, secret string, msg string) errors.SDKError) *SDK_SendMessage_Call { +func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, domainID string, topic string, msg string, secret string) errors.SDKError) *SDK_SendMessage_Call { _c.Call.Return(run) return _c } @@ -6398,16 +6407,16 @@ func (_c *SDK_SetGroupParent_Call) RunAndReturn(run func(ctx context.Context, id } // UpdateBootstrap provides a mock function for the type SDK -func (_mock *SDK) UpdateBootstrap(cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError { - ret := _mock.Called(cfg, domainID, token) +func (_mock *SDK) UpdateBootstrap(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, cfg, domainID, token) if len(ret) == 0 { panic("no return value specified for UpdateBootstrap") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.BootstrapConfig, string, string) errors.SDKError); ok { - r0 = returnFunc(cfg, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapConfig, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, cfg, domainID, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -6422,16 +6431,17 @@ type SDK_UpdateBootstrap_Call struct { } // UpdateBootstrap is a helper method to define mock.On call +// - ctx // - cfg // - domainID // - token -func (_e *SDK_Expecter) UpdateBootstrap(cfg interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrap_Call { - return &SDK_UpdateBootstrap_Call{Call: _e.mock.On("UpdateBootstrap", cfg, domainID, token)} +func (_e *SDK_Expecter) UpdateBootstrap(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrap_Call { + return &SDK_UpdateBootstrap_Call{Call: _e.mock.On("UpdateBootstrap", ctx, cfg, domainID, token)} } -func (_c *SDK_UpdateBootstrap_Call) Run(run func(cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_UpdateBootstrap_Call { +func (_c *SDK_UpdateBootstrap_Call) Run(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_UpdateBootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.BootstrapConfig), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(sdk.BootstrapConfig), args[2].(string), args[3].(string)) }) return _c } @@ -6441,14 +6451,14 @@ func (_c *SDK_UpdateBootstrap_Call) Return(sDKError errors.SDKError) *SDK_Update return _c } -func (_c *SDK_UpdateBootstrap_Call) RunAndReturn(run func(cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrap_Call { +func (_c *SDK_UpdateBootstrap_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrap_Call { _c.Call.Return(run) return _c } // UpdateBootstrapCerts provides a mock function for the type SDK -func (_mock *SDK) UpdateBootstrapCerts(id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { - ret := _mock.Called(id, clientCert, clientKey, ca, domainID, token) +func (_mock *SDK) UpdateBootstrapCerts(ctx context.Context, id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, id, clientCert, clientKey, ca, domainID, token) if len(ret) == 0 { panic("no return value specified for UpdateBootstrapCerts") @@ -6456,16 +6466,16 @@ func (_mock *SDK) UpdateBootstrapCerts(id string, clientCert string, clientKey s var r0 sdk.BootstrapConfig var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { - return returnFunc(id, clientCert, clientKey, ca, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, id, clientCert, clientKey, ca, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(string, string, string, string, string, string) sdk.BootstrapConfig); ok { - r0 = returnFunc(id, clientCert, clientKey, ca, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, id, clientCert, clientKey, ca, domainID, token) } else { r0 = ret.Get(0).(sdk.BootstrapConfig) } - if returnFunc, ok := ret.Get(1).(func(string, string, string, string, string, string) errors.SDKError); ok { - r1 = returnFunc(id, clientCert, clientKey, ca, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, clientCert, clientKey, ca, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -6480,19 +6490,20 @@ type SDK_UpdateBootstrapCerts_Call struct { } // UpdateBootstrapCerts is a helper method to define mock.On call +// - ctx // - id // - clientCert // - clientKey // - ca // - domainID // - token -func (_e *SDK_Expecter) UpdateBootstrapCerts(id interface{}, clientCert interface{}, clientKey interface{}, ca interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapCerts_Call { - return &SDK_UpdateBootstrapCerts_Call{Call: _e.mock.On("UpdateBootstrapCerts", id, clientCert, clientKey, ca, domainID, token)} +func (_e *SDK_Expecter) UpdateBootstrapCerts(ctx interface{}, id interface{}, clientCert interface{}, clientKey interface{}, ca interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapCerts_Call { + return &SDK_UpdateBootstrapCerts_Call{Call: _e.mock.On("UpdateBootstrapCerts", ctx, id, clientCert, clientKey, ca, domainID, token)} } -func (_c *SDK_UpdateBootstrapCerts_Call) Run(run func(id string, clientCert string, clientKey string, ca string, domainID string, token string)) *SDK_UpdateBootstrapCerts_Call { +func (_c *SDK_UpdateBootstrapCerts_Call) Run(run func(ctx context.Context, id string, clientCert string, clientKey string, ca string, domainID string, token string)) *SDK_UpdateBootstrapCerts_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(string), args[6].(string)) }) return _c } @@ -6502,22 +6513,22 @@ func (_c *SDK_UpdateBootstrapCerts_Call) Return(bootstrapConfig sdk.BootstrapCon return _c } -func (_c *SDK_UpdateBootstrapCerts_Call) RunAndReturn(run func(id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_UpdateBootstrapCerts_Call { +func (_c *SDK_UpdateBootstrapCerts_Call) RunAndReturn(run func(ctx context.Context, id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_UpdateBootstrapCerts_Call { _c.Call.Return(run) return _c } // UpdateBootstrapConnection provides a mock function for the type SDK -func (_mock *SDK) UpdateBootstrapConnection(id string, channels []string, domainID string, token string) errors.SDKError { - ret := _mock.Called(id, channels, domainID, token) +func (_mock *SDK) UpdateBootstrapConnection(ctx context.Context, id string, channels []string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, channels, domainID, token) if len(ret) == 0 { panic("no return value specified for UpdateBootstrapConnection") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, []string, string, string) errors.SDKError); ok { - r0 = returnFunc(id, channels, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, channels, domainID, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -6532,17 +6543,18 @@ type SDK_UpdateBootstrapConnection_Call struct { } // UpdateBootstrapConnection is a helper method to define mock.On call +// - ctx // - id // - channels // - domainID // - token -func (_e *SDK_Expecter) UpdateBootstrapConnection(id interface{}, channels interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapConnection_Call { - return &SDK_UpdateBootstrapConnection_Call{Call: _e.mock.On("UpdateBootstrapConnection", id, channels, domainID, token)} +func (_e *SDK_Expecter) UpdateBootstrapConnection(ctx interface{}, id interface{}, channels interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapConnection_Call { + return &SDK_UpdateBootstrapConnection_Call{Call: _e.mock.On("UpdateBootstrapConnection", ctx, id, channels, domainID, token)} } -func (_c *SDK_UpdateBootstrapConnection_Call) Run(run func(id string, channels []string, domainID string, token string)) *SDK_UpdateBootstrapConnection_Call { +func (_c *SDK_UpdateBootstrapConnection_Call) Run(run func(ctx context.Context, id string, channels []string, domainID string, token string)) *SDK_UpdateBootstrapConnection_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].([]string), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(string), args[2].([]string), args[3].(string), args[4].(string)) }) return _c } @@ -6552,7 +6564,7 @@ func (_c *SDK_UpdateBootstrapConnection_Call) Return(sDKError errors.SDKError) * return _c } -func (_c *SDK_UpdateBootstrapConnection_Call) RunAndReturn(run func(id string, channels []string, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrapConnection_Call { +func (_c *SDK_UpdateBootstrapConnection_Call) RunAndReturn(run func(ctx context.Context, id string, channels []string, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrapConnection_Call { _c.Call.Return(run) return _c } @@ -7733,8 +7745,8 @@ func (_c *SDK_Users_Call) RunAndReturn(run func(ctx context.Context, pm sdk0.Pag } // ViewBootstrap provides a mock function for the type SDK -func (_mock *SDK) ViewBootstrap(id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { - ret := _mock.Called(id, domainID, token) +func (_mock *SDK) ViewBootstrap(ctx context.Context, id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) if len(ret) == 0 { panic("no return value specified for ViewBootstrap") @@ -7742,16 +7754,16 @@ func (_mock *SDK) ViewBootstrap(id string, domainID string, token string) (sdk.B var r0 sdk.BootstrapConfig var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { - return returnFunc(id, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(string, string, string) sdk.BootstrapConfig); ok { - r0 = returnFunc(id, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, id, domainID, token) } else { r0 = ret.Get(0).(sdk.BootstrapConfig) } - if returnFunc, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { - r1 = returnFunc(id, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -7766,16 +7778,17 @@ type SDK_ViewBootstrap_Call struct { } // ViewBootstrap is a helper method to define mock.On call +// - ctx // - id // - domainID // - token -func (_e *SDK_Expecter) ViewBootstrap(id interface{}, domainID interface{}, token interface{}) *SDK_ViewBootstrap_Call { - return &SDK_ViewBootstrap_Call{Call: _e.mock.On("ViewBootstrap", id, domainID, token)} +func (_e *SDK_Expecter) ViewBootstrap(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_ViewBootstrap_Call { + return &SDK_ViewBootstrap_Call{Call: _e.mock.On("ViewBootstrap", ctx, id, domainID, token)} } -func (_c *SDK_ViewBootstrap_Call) Run(run func(id string, domainID string, token string)) *SDK_ViewBootstrap_Call { +func (_c *SDK_ViewBootstrap_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_ViewBootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -7785,7 +7798,7 @@ func (_c *SDK_ViewBootstrap_Call) Return(bootstrapConfig sdk.BootstrapConfig, sD return _c } -func (_c *SDK_ViewBootstrap_Call) RunAndReturn(run func(id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_ViewBootstrap_Call { +func (_c *SDK_ViewBootstrap_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_ViewBootstrap_Call { _c.Call.Return(run) return _c } @@ -7909,8 +7922,8 @@ func (_c *SDK_ViewCertByClient_Call) RunAndReturn(run func(ctx context.Context, } // ViewSubscription provides a mock function for the type SDK -func (_mock *SDK) ViewSubscription(id string, token string) (sdk.Subscription, errors.SDKError) { - ret := _mock.Called(id, token) +func (_mock *SDK) ViewSubscription(ctx context.Context, id string, token string) (sdk.Subscription, errors.SDKError) { + ret := _mock.Called(ctx, id, token) if len(ret) == 0 { panic("no return value specified for ViewSubscription") @@ -7918,16 +7931,16 @@ func (_mock *SDK) ViewSubscription(id string, token string) (sdk.Subscription, e var r0 sdk.Subscription var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string) (sdk.Subscription, errors.SDKError)); ok { - return returnFunc(id, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (sdk.Subscription, errors.SDKError)); ok { + return returnFunc(ctx, id, token) } - if returnFunc, ok := ret.Get(0).(func(string, string) sdk.Subscription); ok { - r0 = returnFunc(id, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) sdk.Subscription); ok { + r0 = returnFunc(ctx, id, token) } else { r0 = ret.Get(0).(sdk.Subscription) } - if returnFunc, ok := ret.Get(1).(func(string, string) errors.SDKError); ok { - r1 = returnFunc(id, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -7942,15 +7955,16 @@ type SDK_ViewSubscription_Call struct { } // ViewSubscription is a helper method to define mock.On call +// - ctx // - id // - token -func (_e *SDK_Expecter) ViewSubscription(id interface{}, token interface{}) *SDK_ViewSubscription_Call { - return &SDK_ViewSubscription_Call{Call: _e.mock.On("ViewSubscription", id, token)} +func (_e *SDK_Expecter) ViewSubscription(ctx interface{}, id interface{}, token interface{}) *SDK_ViewSubscription_Call { + return &SDK_ViewSubscription_Call{Call: _e.mock.On("ViewSubscription", ctx, id, token)} } -func (_c *SDK_ViewSubscription_Call) Run(run func(id string, token string)) *SDK_ViewSubscription_Call { +func (_c *SDK_ViewSubscription_Call) Run(run func(ctx context.Context, id string, token string)) *SDK_ViewSubscription_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string)) }) return _c } @@ -7960,22 +7974,22 @@ func (_c *SDK_ViewSubscription_Call) Return(subscription sdk.Subscription, sDKEr return _c } -func (_c *SDK_ViewSubscription_Call) RunAndReturn(run func(id string, token string) (sdk.Subscription, errors.SDKError)) *SDK_ViewSubscription_Call { +func (_c *SDK_ViewSubscription_Call) RunAndReturn(run func(ctx context.Context, id string, token string) (sdk.Subscription, errors.SDKError)) *SDK_ViewSubscription_Call { _c.Call.Return(run) return _c } // Whitelist provides a mock function for the type SDK -func (_mock *SDK) Whitelist(clientID string, state int, domainID string, token string) errors.SDKError { - ret := _mock.Called(clientID, state, domainID, token) +func (_mock *SDK) Whitelist(ctx context.Context, clientID string, state int, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, clientID, state, domainID, token) if len(ret) == 0 { panic("no return value specified for Whitelist") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, int, string, string) errors.SDKError); ok { - r0 = returnFunc(clientID, state, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, clientID, state, domainID, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -7990,17 +8004,18 @@ type SDK_Whitelist_Call struct { } // Whitelist is a helper method to define mock.On call +// - ctx // - clientID // - state // - domainID // - token -func (_e *SDK_Expecter) Whitelist(clientID interface{}, state interface{}, domainID interface{}, token interface{}) *SDK_Whitelist_Call { - return &SDK_Whitelist_Call{Call: _e.mock.On("Whitelist", clientID, state, domainID, token)} +func (_e *SDK_Expecter) Whitelist(ctx interface{}, clientID interface{}, state interface{}, domainID interface{}, token interface{}) *SDK_Whitelist_Call { + return &SDK_Whitelist_Call{Call: _e.mock.On("Whitelist", ctx, clientID, state, domainID, token)} } -func (_c *SDK_Whitelist_Call) Run(run func(clientID string, state int, domainID string, token string)) *SDK_Whitelist_Call { +func (_c *SDK_Whitelist_Call) Run(run func(ctx context.Context, clientID string, state int, domainID string, token string)) *SDK_Whitelist_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(int), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(int), args[3].(string), args[4].(string)) }) return _c } @@ -8010,7 +8025,7 @@ func (_c *SDK_Whitelist_Call) Return(sDKError errors.SDKError) *SDK_Whitelist_Ca return _c } -func (_c *SDK_Whitelist_Call) RunAndReturn(run func(clientID string, state int, domainID string, token string) errors.SDKError) *SDK_Whitelist_Call { +func (_c *SDK_Whitelist_Call) RunAndReturn(run func(ctx context.Context, clientID string, state int, domainID string, token string) errors.SDKError) *SDK_Whitelist_Call { _c.Call.Return(run) return _c } diff --git a/pkg/sdk/sdk.go b/pkg/sdk/sdk.go index 7806246c7..2069df19c 100644 --- a/pkg/sdk/sdk.go +++ b/pkg/sdk/sdk.go @@ -5,6 +5,7 @@ package sdk import ( "bytes" + "context" "crypto/tls" "encoding/json" "fmt" @@ -67,16 +68,16 @@ type SDK interface { // ExternalKey: "externalKey", // Channels: []string{"channel1", "channel2"}, // } - // id, _ := sdk.AddBootstrap(cfg, "domainID", "token") + // id, _ := sdk.AddBootstrap(ctx, cfg, "domainID", "token") // fmt.Println(id) - AddBootstrap(cfg BootstrapConfig, domainID, token string) (string, errors.SDKError) + AddBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) (string, errors.SDKError) // View returns Client Config with given ID belonging to the user identified by the given token. // // example: - // bootstrap, _ := sdk.ViewBootstrap("id", "domainID", "token") + // bootstrap, _ := sdk.ViewBootstrap(ctx, "id", "domainID", "token") // fmt.Println(bootstrap) - ViewBootstrap(id, domainID, token string) (BootstrapConfig, errors.SDKError) + ViewBootstrap(ctx context.Context, id, domainID, token string) (BootstrapConfig, errors.SDKError) // Update updates editable fields of the provided Config. // @@ -88,44 +89,44 @@ type SDK interface { // ExternalKey: "externalKey", // Channels: []string{"channel1", "channel2"}, // } - // err := sdk.UpdateBootstrap(cfg, "domainID", "token") + // err := sdk.UpdateBootstrap(ctx, cfg, "domainID", "token") // fmt.Println(err) - UpdateBootstrap(cfg BootstrapConfig, domainID, token string) errors.SDKError + UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) errors.SDKError // Update bootstrap config certificates. // // example: - // err := sdk.UpdateBootstrapCerts("id", "clientCert", "clientKey", "ca", "domainID", "token") + // err := sdk.UpdateBootstrapCerts(ctx, "id", "clientCert", "clientKey", "ca", "domainID", "token") // fmt.Println(err) - UpdateBootstrapCerts(id string, clientCert, clientKey, ca string, domainID, token string) (BootstrapConfig, errors.SDKError) + UpdateBootstrapCerts(ctx context.Context, id string, clientCert, clientKey, ca string, domainID, token string) (BootstrapConfig, errors.SDKError) // UpdateBootstrapConnection updates connections performs update of the channel list corresponding Client is connected to. // // example: - // err := sdk.UpdateBootstrapConnection("id", []string{"channel1", "channel2"}, "domainID", "token") + // err := sdk.UpdateBootstrapConnection(ctx, "id", []string{"channel1", "channel2"}, "domainID", "token") // fmt.Println(err) - UpdateBootstrapConnection(id string, channels []string, domainID, token string) errors.SDKError + UpdateBootstrapConnection(ctx context.Context, id string, channels []string, domainID, token string) errors.SDKError // Remove removes Config with specified token that belongs to the user identified by the given token. // // example: - // err := sdk.RemoveBootstrap("id", "domainID", "token") + // err := sdk.RemoveBootstrap(ctx, "id", "domainID", "token") // fmt.Println(err) - RemoveBootstrap(id, domainID, token string) errors.SDKError + RemoveBootstrap(ctx context.Context, id, domainID, token string) errors.SDKError // Bootstrap returns Config to the Client with provided external ID using external key. // // example: - // bootstrap, _ := sdk.Bootstrap("externalID", "externalKey") + // bootstrap, _ := sdk.Bootstrap(ctx, "externalID", "externalKey") // fmt.Println(bootstrap) - Bootstrap(externalID, externalKey string) (BootstrapConfig, errors.SDKError) + Bootstrap(ctx context.Context, externalID, externalKey string) (BootstrapConfig, errors.SDKError) // BootstrapSecure retrieves a configuration with given external ID and encrypted external key. // // example: - // bootstrap, _ := sdk.BootstrapSecure("externalID", "externalKey", "cryptoKey") + // bootstrap, _ := sdk.BootstrapSecure(ctx, "externalID", "externalKey", "cryptoKey") // fmt.Println(bootstrap) - BootstrapSecure(externalID, externalKey, cryptoKey string) (BootstrapConfig, errors.SDKError) + BootstrapSecure(ctx context.Context, externalID, externalKey, cryptoKey string) (BootstrapConfig, errors.SDKError) // Bootstraps retrieves a list of managed configs. // @@ -134,16 +135,16 @@ type SDK interface { // Offset: 0, // Limit: 10, // } - // bootstraps, _ := sdk.Bootstraps(pm, "domainID", "token") + // bootstraps, _ := sdk.Bootstraps(ctx, pm, "domainID", "token") // fmt.Println(bootstraps) - Bootstraps(pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) + Bootstraps(ctx context.Context, pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) // Whitelist updates Client state Config with given ID belonging to the user identified by the given token. // // example: - // err := sdk.Whitelist("clientID", 1, "domainID", "token") + // err := sdk.Whitelist(ctx, "clientID", 1, "domainID", "token") // fmt.Println(err) - Whitelist(clientID string, state int, domainID, token string) errors.SDKError + Whitelist(ctx context.Context, clientID string, state int, domainID, token string) errors.SDKError // ReadMessages read messages of specified channel. // @@ -152,16 +153,16 @@ type SDK interface { // Offset: 0, // Limit: 10, // } - // msgs, _ := sdk.ReadMessages(pm,"channelID", "domainID", "token") + // msgs, _ := sdk.ReadMessages(ctx, pm,"channelID", "domainID", "token") // fmt.Println(msgs) - ReadMessages(pm MessagePageMetadata, chanID, domainID, token string) (MessagesPage, errors.SDKError) + ReadMessages(ctx context.Context, pm MessagePageMetadata, chanID, domainID, token string) (MessagesPage, errors.SDKError) // CreateSubscription creates a new subscription // // example: - // subscription, _ := sdk.CreateSubscription("topic", "contact", "token") + // subscription, _ := sdk.CreateSubscription(ctx, "topic", "contact", "token") // fmt.Println(subscription) - CreateSubscription(topic, contact, token string) (string, errors.SDKError) + CreateSubscription(ctx context.Context, topic, contact, token string) (string, errors.SDKError) // ListSubscriptions list subscriptions given list parameters. // @@ -170,23 +171,23 @@ type SDK interface { // Offset: 0, // Limit: 10, // } - // subscriptions, _ := sdk.ListSubscriptions(pm, "token") + // subscriptions, _ := sdk.ListSubscriptions(ctx, pm, "token") // fmt.Println(subscriptions) - ListSubscriptions(pm PageMetadata, token string) (SubscriptionPage, errors.SDKError) + ListSubscriptions(ctx context.Context, pm PageMetadata, token string) (SubscriptionPage, errors.SDKError) // ViewSubscription retrieves a subscription with the provided id. // // example: - // subscription, _ := sdk.ViewSubscription("id", "token") + // subscription, _ := sdk.ViewSubscription(ctx, "id", "token") // fmt.Println(subscription) - ViewSubscription(id, token string) (Subscription, errors.SDKError) + ViewSubscription(ctx context.Context, id, token string) (Subscription, errors.SDKError) // DeleteSubscription removes a subscription with the provided id. // // example: - // err := sdk.DeleteSubscription("id", "token") + // err := sdk.DeleteSubscription(ctx, "id", "token") // fmt.Println(err) - DeleteSubscription(id, token string) errors.SDKError + DeleteSubscription(ctx context.Context, id, token string) errors.SDKError } type mgSDK struct { @@ -257,8 +258,8 @@ func NewSDK(conf Config) SDK { // processRequest creates and send a new HTTP request, and checks for errors in the HTTP response. // It then returns the response headers, the response body, and the associated error(s) (if any). -func (sdk mgSDK) processRequest(method, reqUrl, token string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) { - req, err := http.NewRequest(method, reqUrl, bytes.NewReader(data)) +func (sdk mgSDK) processRequest(ctx context.Context, method, reqUrl, token string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) { + req, err := http.NewRequestWithContext(ctx, method, reqUrl, bytes.NewReader(data)) if err != nil { return make(http.Header), []byte{}, errors.NewSDKError(err) } diff --git a/provision/api/endpoint.go b/provision/api/endpoint.go index 2cdb55010..02c88ad74 100644 --- a/provision/api/endpoint.go +++ b/provision/api/endpoint.go @@ -13,13 +13,13 @@ import ( ) func doProvision(svc provision.Service) endpoint.Endpoint { - return func(_ context.Context, request interface{}) (interface{}, error) { + return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(provisionReq) if err := req.validate(); err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - res, err := svc.Provision(req.domainID, req.token, req.Name, req.ExternalID, req.ExternalKey) + res, err := svc.Provision(ctx, req.domainID, req.token, req.Name, req.ExternalID, req.ExternalKey) if err != nil { return nil, err } @@ -38,13 +38,13 @@ func doProvision(svc provision.Service) endpoint.Endpoint { } func getMapping(svc provision.Service) endpoint.Endpoint { - return func(_ context.Context, request interface{}) (interface{}, error) { + return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(mappingReq) if err := req.validate(); err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - res, err := svc.Mapping(req.token) + res, err := svc.Mapping(ctx, req.token) if err != nil { return nil, err } diff --git a/provision/api/endpoint_test.go b/provision/api/endpoint_test.go index 13f8eae7e..8ec142104 100644 --- a/provision/api/endpoint_test.go +++ b/provision/api/endpoint_test.go @@ -19,6 +19,7 @@ import ( smqlog "github.com/absmach/supermq/logger" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) var ( @@ -140,7 +141,7 @@ func TestProvision(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repocall := svc.On("Provision", validID, tc.token, "test", validID, validID).Return(provision.Result{}, tc.svcErr) + repocall := svc.On("Provision", mock.Anything, validID, tc.token, "test", validID, validID).Return(provision.Result{}, tc.svcErr) req := testRequest{ client: is.Client(), method: http.MethodPost, @@ -205,7 +206,7 @@ func TestMapping(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repocall := svc.On("Mapping", tc.token).Return(map[string]interface{}{}, tc.svcErr) + repocall := svc.On("Mapping", mock.Anything, tc.token).Return(map[string]interface{}{}, tc.svcErr) req := testRequest{ client: is.Client(), method: http.MethodGet, diff --git a/provision/api/logging.go b/provision/api/logging.go index 2156c722f..81b83c303 100644 --- a/provision/api/logging.go +++ b/provision/api/logging.go @@ -6,6 +6,7 @@ package api import ( + "context" "log/slog" "time" @@ -24,7 +25,7 @@ func NewLoggingMiddleware(svc provision.Service, logger *slog.Logger) provision. return &loggingMiddleware{logger, svc} } -func (lm *loggingMiddleware) Provision(domainID, token, name, externalID, externalKey string) (res provision.Result, err error) { +func (lm *loggingMiddleware) Provision(ctx context.Context, domainID, token, name, externalID, externalKey string) (res provision.Result, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), @@ -39,10 +40,10 @@ func (lm *loggingMiddleware) Provision(domainID, token, name, externalID, extern lm.logger.Info("Provision completed successfully", args...) }(time.Now()) - return lm.svc.Provision(domainID, token, name, externalID, externalKey) + return lm.svc.Provision(ctx, domainID, token, name, externalID, externalKey) } -func (lm *loggingMiddleware) Cert(domainID, token, clientID, duration string) (cert, key string, err error) { +func (lm *loggingMiddleware) Cert(ctx context.Context, domainID, token, clientID, duration string) (cert, key string, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), @@ -57,10 +58,10 @@ func (lm *loggingMiddleware) Cert(domainID, token, clientID, duration string) (c lm.logger.Info("Client certificate created successfully", args...) }(time.Now()) - return lm.svc.Cert(domainID, token, clientID, duration) + return lm.svc.Cert(ctx, domainID, token, clientID, duration) } -func (lm *loggingMiddleware) Mapping(token string) (res map[string]interface{}, err error) { +func (lm *loggingMiddleware) Mapping(ctx context.Context, token string) (res map[string]interface{}, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), @@ -73,5 +74,5 @@ func (lm *loggingMiddleware) Mapping(token string) (res map[string]interface{}, lm.logger.Info("Mapping completed successfully", args...) }(time.Now()) - return lm.svc.Mapping(token) + return lm.svc.Mapping(ctx, token) } diff --git a/provision/mocks/service.go b/provision/mocks/service.go index 68e708164..58c6fb6b4 100644 --- a/provision/mocks/service.go +++ b/provision/mocks/service.go @@ -8,6 +8,8 @@ package mocks import ( + "context" + "github.com/absmach/magistrala/provision" mock "github.com/stretchr/testify/mock" ) @@ -40,8 +42,8 @@ func (_m *Service) EXPECT() *Service_Expecter { } // Cert provides a mock function for the type Service -func (_mock *Service) Cert(domainID string, token string, clientID string, duration string) (string, string, error) { - ret := _mock.Called(domainID, token, clientID, duration) +func (_mock *Service) Cert(ctx context.Context, domainID string, token string, clientID string, duration string) (string, string, error) { + ret := _mock.Called(ctx, domainID, token, clientID, duration) if len(ret) == 0 { panic("no return value specified for Cert") @@ -50,21 +52,21 @@ func (_mock *Service) Cert(domainID string, token string, clientID string, durat var r0 string var r1 string var r2 error - if returnFunc, ok := ret.Get(0).(func(string, string, string, string) (string, string, error)); ok { - return returnFunc(domainID, token, clientID, duration) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) (string, string, error)); ok { + return returnFunc(ctx, domainID, token, clientID, duration) } - if returnFunc, ok := ret.Get(0).(func(string, string, string, string) string); ok { - r0 = returnFunc(domainID, token, clientID, duration) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) string); ok { + r0 = returnFunc(ctx, domainID, token, clientID, duration) } else { r0 = ret.Get(0).(string) } - if returnFunc, ok := ret.Get(1).(func(string, string, string, string) string); ok { - r1 = returnFunc(domainID, token, clientID, duration) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string) string); ok { + r1 = returnFunc(ctx, domainID, token, clientID, duration) } else { r1 = ret.Get(1).(string) } - if returnFunc, ok := ret.Get(2).(func(string, string, string, string) error); ok { - r2 = returnFunc(domainID, token, clientID, duration) + if returnFunc, ok := ret.Get(2).(func(context.Context, string, string, string, string) error); ok { + r2 = returnFunc(ctx, domainID, token, clientID, duration) } else { r2 = ret.Error(2) } @@ -77,17 +79,18 @@ type Service_Cert_Call struct { } // Cert is a helper method to define mock.On call +// - ctx // - domainID // - token // - clientID // - duration -func (_e *Service_Expecter) Cert(domainID interface{}, token interface{}, clientID interface{}, duration interface{}) *Service_Cert_Call { - return &Service_Cert_Call{Call: _e.mock.On("Cert", domainID, token, clientID, duration)} +func (_e *Service_Expecter) Cert(ctx interface{}, domainID interface{}, token interface{}, clientID interface{}, duration interface{}) *Service_Cert_Call { + return &Service_Cert_Call{Call: _e.mock.On("Cert", ctx, domainID, token, clientID, duration)} } -func (_c *Service_Cert_Call) Run(run func(domainID string, token string, clientID string, duration string)) *Service_Cert_Call { +func (_c *Service_Cert_Call) Run(run func(ctx context.Context, domainID string, token string, clientID string, duration string)) *Service_Cert_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) }) return _c } @@ -97,14 +100,14 @@ func (_c *Service_Cert_Call) Return(s string, s1 string, err error) *Service_Cer return _c } -func (_c *Service_Cert_Call) RunAndReturn(run func(domainID string, token string, clientID string, duration string) (string, string, error)) *Service_Cert_Call { +func (_c *Service_Cert_Call) RunAndReturn(run func(ctx context.Context, domainID string, token string, clientID string, duration string) (string, string, error)) *Service_Cert_Call { _c.Call.Return(run) return _c } // Mapping provides a mock function for the type Service -func (_mock *Service) Mapping(token string) (map[string]interface{}, error) { - ret := _mock.Called(token) +func (_mock *Service) Mapping(ctx context.Context, token string) (map[string]interface{}, error) { + ret := _mock.Called(ctx, token) if len(ret) == 0 { panic("no return value specified for Mapping") @@ -112,18 +115,18 @@ func (_mock *Service) Mapping(token string) (map[string]interface{}, error) { var r0 map[string]interface{} var r1 error - if returnFunc, ok := ret.Get(0).(func(string) (map[string]interface{}, error)); ok { - return returnFunc(token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (map[string]interface{}, error)); ok { + return returnFunc(ctx, token) } - if returnFunc, ok := ret.Get(0).(func(string) map[string]interface{}); ok { - r0 = returnFunc(token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) map[string]interface{}); ok { + r0 = returnFunc(ctx, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[string]interface{}) } } - if returnFunc, ok := ret.Get(1).(func(string) error); ok { - r1 = returnFunc(token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, token) } else { r1 = ret.Error(1) } @@ -136,14 +139,15 @@ type Service_Mapping_Call struct { } // Mapping is a helper method to define mock.On call +// - ctx // - token -func (_e *Service_Expecter) Mapping(token interface{}) *Service_Mapping_Call { - return &Service_Mapping_Call{Call: _e.mock.On("Mapping", token)} +func (_e *Service_Expecter) Mapping(ctx interface{}, token interface{}) *Service_Mapping_Call { + return &Service_Mapping_Call{Call: _e.mock.On("Mapping", ctx, token)} } -func (_c *Service_Mapping_Call) Run(run func(token string)) *Service_Mapping_Call { +func (_c *Service_Mapping_Call) Run(run func(ctx context.Context, token string)) *Service_Mapping_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) + run(args[0].(context.Context), args[1].(string)) }) return _c } @@ -153,14 +157,14 @@ func (_c *Service_Mapping_Call) Return(stringToIfaceVal map[string]interface{}, return _c } -func (_c *Service_Mapping_Call) RunAndReturn(run func(token string) (map[string]interface{}, error)) *Service_Mapping_Call { +func (_c *Service_Mapping_Call) RunAndReturn(run func(ctx context.Context, token string) (map[string]interface{}, error)) *Service_Mapping_Call { _c.Call.Return(run) return _c } // Provision provides a mock function for the type Service -func (_mock *Service) Provision(domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error) { - ret := _mock.Called(domainID, token, name, externalID, externalKey) +func (_mock *Service) Provision(ctx context.Context, domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error) { + ret := _mock.Called(ctx, domainID, token, name, externalID, externalKey) if len(ret) == 0 { panic("no return value specified for Provision") @@ -168,16 +172,16 @@ func (_mock *Service) Provision(domainID string, token string, name string, exte var r0 provision.Result var r1 error - if returnFunc, ok := ret.Get(0).(func(string, string, string, string, string) (provision.Result, error)); ok { - return returnFunc(domainID, token, name, externalID, externalKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) (provision.Result, error)); ok { + return returnFunc(ctx, domainID, token, name, externalID, externalKey) } - if returnFunc, ok := ret.Get(0).(func(string, string, string, string, string) provision.Result); ok { - r0 = returnFunc(domainID, token, name, externalID, externalKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) provision.Result); ok { + r0 = returnFunc(ctx, domainID, token, name, externalID, externalKey) } else { r0 = ret.Get(0).(provision.Result) } - if returnFunc, ok := ret.Get(1).(func(string, string, string, string, string) error); ok { - r1 = returnFunc(domainID, token, name, externalID, externalKey) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string) error); ok { + r1 = returnFunc(ctx, domainID, token, name, externalID, externalKey) } else { r1 = ret.Error(1) } @@ -190,18 +194,19 @@ type Service_Provision_Call struct { } // Provision is a helper method to define mock.On call +// - ctx // - domainID // - token // - name // - externalID // - externalKey -func (_e *Service_Expecter) Provision(domainID interface{}, token interface{}, name interface{}, externalID interface{}, externalKey interface{}) *Service_Provision_Call { - return &Service_Provision_Call{Call: _e.mock.On("Provision", domainID, token, name, externalID, externalKey)} +func (_e *Service_Expecter) Provision(ctx interface{}, domainID interface{}, token interface{}, name interface{}, externalID interface{}, externalKey interface{}) *Service_Provision_Call { + return &Service_Provision_Call{Call: _e.mock.On("Provision", ctx, domainID, token, name, externalID, externalKey)} } -func (_c *Service_Provision_Call) Run(run func(domainID string, token string, name string, externalID string, externalKey string)) *Service_Provision_Call { +func (_c *Service_Provision_Call) Run(run func(ctx context.Context, domainID string, token string, name string, externalID string, externalKey string)) *Service_Provision_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(string)) }) return _c } @@ -211,7 +216,7 @@ func (_c *Service_Provision_Call) Return(result provision.Result, err error) *Se return _c } -func (_c *Service_Provision_Call) RunAndReturn(run func(domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error)) *Service_Provision_Call { +func (_c *Service_Provision_Call) RunAndReturn(run func(ctx context.Context, domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error)) *Service_Provision_Call { _c.Call.Return(run) return _c } diff --git a/provision/service.go b/provision/service.go index 405937719..f4628564f 100644 --- a/provision/service.go +++ b/provision/service.go @@ -56,18 +56,18 @@ type Service interface { // - create multiple Channels // - create Bootstrap configuration // - whitelist Client in Bootstrap configuration == connect Client to Channels - Provision(domainID, token, name, externalID, externalKey string) (Result, error) + Provision(ctx context.Context, domainID, token, name, externalID, externalKey string) (Result, error) // Mapping returns current configuration used for provision // useful for using in ui to create configuration that matches // one created with Provision method. - Mapping(token string) (map[string]interface{}, error) + Mapping(ctx context.Context, token string) (map[string]interface{}, error) // Certs creates certificate for clients that communicate over mTLS // A duration string is a possibly signed sequence of decimal numbers, // each with optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". - Cert(domainID, token, clientID, duration string) (string, string, error) + Cert(ctx context.Context, domainID, token, clientID, duration string) (string, string, error) } type provisionService struct { @@ -97,13 +97,13 @@ func New(cfg Config, mgsdk sdk.SDK, logger *slog.Logger) Service { } // Mapping retrieves current configuration. -func (ps *provisionService) Mapping(token string) (map[string]interface{}, error) { +func (ps *provisionService) Mapping(ctx context.Context, token string) (map[string]interface{}, error) { pm := smqSDK.PageMetadata{ Offset: uint64(offset), Limit: uint64(limit), } - if _, err := ps.sdk.Users(context.Background(), pm, token); err != nil { + if _, err := ps.sdk.Users(ctx, pm, token); err != nil { return map[string]interface{}{}, errors.Wrap(ErrUnauthorized, err) } @@ -112,12 +112,12 @@ func (ps *provisionService) Mapping(token string) (map[string]interface{}, error // Provision is provision method for creating setup according to // provision layout specified in config.toml. -func (ps *provisionService) Provision(domainID, token, name, externalID, externalKey string) (res Result, err error) { +func (ps *provisionService) Provision(ctx context.Context, domainID, token, name, externalID, externalKey string) (res Result, err error) { var channels []smqSDK.Channel var clients []smqSDK.Client - defer ps.recover(&err, &clients, &channels, &domainID, &token) + defer ps.recover(ctx, &err, &clients, &channels, &domainID, &token) - token, err = ps.createTokenIfEmpty(token) + token, err = ps.createTokenIfEmpty(ctx, token) if err != nil { return res, errors.Wrap(ErrFailedToCreateToken, err) } @@ -142,14 +142,14 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa name = c.Name } cli.Name = name - cli, err := ps.sdk.CreateClient(context.Background(), cli, domainID, token) + cli, err := ps.sdk.CreateClient(ctx, cli, domainID, token) if err != nil { res.Error = err.Error() return res, errors.Wrap(ErrFailedClientCreation, err) } // Get newly created client (in order to get the key). - cli, err = ps.sdk.Client(context.Background(), cli.ID, domainID, token) + cli, err = ps.sdk.Client(ctx, cli.ID, domainID, token) if err != nil { e := errors.Wrap(err, fmt.Errorf("client id: %s", cli.ID)) return res, errors.Wrap(ErrFailedClientRetrieval, e) @@ -162,11 +162,11 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa Name: name + "_" + channel.Name, Metadata: smqSDK.Metadata(channel.Metadata), } - ch, err := ps.sdk.CreateChannel(context.Background(), ch, domainID, token) + ch, err := ps.sdk.CreateChannel(ctx, ch, domainID, token) if err != nil { return res, errors.Wrap(ErrFailedChannelCreation, err) } - ch, err = ps.sdk.Channel(context.Background(), ch.ID, domainID, token) + ch, err = ps.sdk.Channel(ctx, ch.ID, domainID, token) if err != nil { e := errors.Wrap(err, fmt.Errorf("channel id: %s", ch.ID)) return res, errors.Wrap(ErrFailedChannelRetrieval, e) @@ -206,12 +206,12 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa ClientKey: cert.Key, Content: string(content), } - bsid, err := ps.sdk.AddBootstrap(bsReq, domainID, token) + bsid, err := ps.sdk.AddBootstrap(ctx, bsReq, domainID, token) if err != nil { return Result{}, errors.Wrap(ErrFailedBootstrap, err) } - bsConfig, err = ps.sdk.ViewBootstrap(bsid, domainID, token) + bsConfig, err = ps.sdk.ViewBootstrap(ctx, bsid, domainID, token) if err != nil { return Result{}, errors.Wrap(ErrFailedBootstrapValidate, err) } @@ -220,12 +220,12 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa if ps.conf.Bootstrap.X509Provision { var cert smqSDK.Cert - cert, err = ps.sdk.IssueCert(context.Background(), c.ID, ps.conf.Cert.TTL, domainID, token) + cert, err = ps.sdk.IssueCert(ctx, c.ID, ps.conf.Cert.TTL, domainID, token) if err != nil { e := errors.Wrap(err, fmt.Errorf("client id: %s", c.ID)) return res, errors.Wrap(ErrFailedCertCreation, e) } - cert, err := ps.sdk.ViewCert(context.Background(), cert.SerialNumber, domainID, token) + cert, err := ps.sdk.ViewCert(ctx, cert.SerialNumber, domainID, token) if err != nil { return res, errors.Wrap(ErrFailedCertView, err) } @@ -235,14 +235,14 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa res.CACert = "" if needsBootstrap(c) { - if _, err = ps.sdk.UpdateBootstrapCerts(bsConfig.ClientID, cert.Certificate, cert.Key, "", domainID, token); err != nil { + if _, err = ps.sdk.UpdateBootstrapCerts(ctx, bsConfig.ClientID, cert.Certificate, cert.Key, "", domainID, token); err != nil { return Result{}, errors.Wrap(ErrFailedCertCreation, err) } } } if ps.conf.Bootstrap.AutoWhiteList { - if err := ps.sdk.Whitelist(c.ID, Active, domainID, token); err != nil { + if err := ps.sdk.Whitelist(ctx, c.ID, Active, domainID, token); err != nil { res.Error = err.Error() return res, ErrClientUpdate } @@ -250,18 +250,17 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa } } - if err = ps.updateGateway(domainID, token, bsConfig, channels); err != nil { + if err = ps.updateGateway(ctx, domainID, token, bsConfig, channels); err != nil { return res, err } return res, nil } -func (ps *provisionService) Cert(domainID, token, clientID, ttl string) (string, string, error) { - token, err := ps.createTokenIfEmpty(token) +func (ps *provisionService) Cert(ctx context.Context, domainID, token, clientID, ttl string) (string, string, error) { + token, err := ps.createTokenIfEmpty(ctx, token) if err != nil { return "", "", errors.Wrap(ErrFailedToCreateToken, err) } - ctx := context.Background() th, err := ps.sdk.Client(ctx, clientID, domainID, token) if err != nil { @@ -278,7 +277,7 @@ func (ps *provisionService) Cert(domainID, token, clientID, ttl string) (string, return cert.Certificate, cert.Key, err } -func (ps *provisionService) createTokenIfEmpty(token string) (string, error) { +func (ps *provisionService) createTokenIfEmpty(ctx context.Context, token string) (string, error) { if token != "" { return token, nil } @@ -298,7 +297,7 @@ func (ps *provisionService) createTokenIfEmpty(token string) (string, error) { Username: ps.conf.Server.MgUsername, Password: ps.conf.Server.MgPass, } - tkn, err := ps.sdk.CreateToken(context.Background(), u) + tkn, err := ps.sdk.CreateToken(ctx, u) if err != nil { return token, errors.Wrap(ErrFailedToCreateToken, err) } @@ -306,7 +305,7 @@ func (ps *provisionService) createTokenIfEmpty(token string) (string, error) { return tkn.AccessToken, nil } -func (ps *provisionService) updateGateway(domainID, token string, bs sdk.BootstrapConfig, channels []smqSDK.Channel) error { +func (ps *provisionService) updateGateway(ctx context.Context, domainID, token string, bs sdk.BootstrapConfig, channels []smqSDK.Channel) error { var gw Gateway for _, ch := range channels { switch ch.Metadata["type"] { @@ -323,7 +322,7 @@ func (ps *provisionService) updateGateway(domainID, token string, bs sdk.Bootstr gw.CfgID = bs.ClientID gw.Type = gateway - c, sdkerr := ps.sdk.Client(context.Background(), bs.ClientID, domainID, token) + c, sdkerr := ps.sdk.Client(ctx, bs.ClientID, domainID, token) if sdkerr != nil { return errors.Wrap(ErrGatewayUpdate, sdkerr) } @@ -334,7 +333,7 @@ func (ps *provisionService) updateGateway(domainID, token string, bs sdk.Bootstr if err := json.Unmarshal(b, &c.Metadata); err != nil { return errors.Wrap(ErrGatewayUpdate, err) } - if _, err := ps.sdk.UpdateClient(context.Background(), c, domainID, token); err != nil { + if _, err := ps.sdk.UpdateClient(ctx, c, domainID, token); err != nil { return errors.Wrap(ErrGatewayUpdate, err) } return nil @@ -346,18 +345,18 @@ func (ps *provisionService) errLog(err error) { } } -func clean(ps *provisionService, clients []smqSDK.Client, channels []smqSDK.Channel, domainID, token string) { +func clean(ctx context.Context, ps *provisionService, clients []smqSDK.Client, channels []smqSDK.Channel, domainID, token string) { for _, t := range clients { - err := ps.sdk.DeleteClient(context.Background(), t.ID, domainID, token) + err := ps.sdk.DeleteClient(ctx, t.ID, domainID, token) ps.errLog(err) } for _, c := range channels { - err := ps.sdk.DeleteChannel(context.Background(), c.ID, domainID, token) + err := ps.sdk.DeleteChannel(ctx, c.ID, domainID, token) ps.errLog(err) } } -func (ps *provisionService) recover(e *error, ths *[]smqSDK.Client, chs *[]smqSDK.Channel, dm, tkn *string) { +func (ps *provisionService) recover(ctx context.Context, e *error, ths *[]smqSDK.Client, chs *[]smqSDK.Channel, dm, tkn *string) { if e == nil { return } @@ -365,49 +364,49 @@ func (ps *provisionService) recover(e *error, ths *[]smqSDK.Client, chs *[]smqSD if errors.Contains(err, ErrFailedClientRetrieval) || errors.Contains(err, ErrFailedChannelCreation) { for _, c := range clients { - err := ps.sdk.DeleteClient(context.Background(), c.ID, domainID, token) + err := ps.sdk.DeleteClient(ctx, c.ID, domainID, token) ps.errLog(err) } return } if errors.Contains(err, ErrFailedBootstrap) || errors.Contains(err, ErrFailedChannelRetrieval) { - clean(ps, clients, channels, domainID, token) + clean(ctx, ps, clients, channels, domainID, token) return } if errors.Contains(err, ErrFailedBootstrapValidate) || errors.Contains(err, ErrFailedCertCreation) { - clean(ps, clients, channels, domainID, token) + clean(ctx, ps, clients, channels, domainID, token) for _, th := range clients { if needsBootstrap(th) { - ps.errLog(ps.sdk.RemoveBootstrap(th.ID, domainID, token)) + ps.errLog(ps.sdk.RemoveBootstrap(ctx, th.ID, domainID, token)) } } return } if errors.Contains(err, ErrFailedBootstrapValidate) || errors.Contains(err, ErrFailedCertCreation) { - clean(ps, clients, channels, domainID, token) + clean(ctx, ps, clients, channels, domainID, token) for _, th := range clients { if needsBootstrap(th) { - bs, err := ps.sdk.ViewBootstrap(th.ID, domainID, token) + bs, err := ps.sdk.ViewBootstrap(ctx, th.ID, domainID, token) ps.errLog(errors.Wrap(ErrFailedBootstrapRetrieval, err)) - ps.errLog(ps.sdk.RemoveBootstrap(bs.ClientID, domainID, token)) + ps.errLog(ps.sdk.RemoveBootstrap(ctx, bs.ClientID, domainID, token)) } } } if errors.Contains(err, ErrClientUpdate) || errors.Contains(err, ErrGatewayUpdate) { - clean(ps, clients, channels, domainID, token) + clean(ctx, ps, clients, channels, domainID, token) for _, th := range clients { if ps.conf.Bootstrap.X509Provision && needsBootstrap(th) { - _, err := ps.sdk.RevokeCert(context.Background(), th.ID, domainID, token) + _, err := ps.sdk.RevokeCert(ctx, th.ID, domainID, token) ps.errLog(err) } if needsBootstrap(th) { - bs, err := ps.sdk.ViewBootstrap(th.ID, domainID, token) + bs, err := ps.sdk.ViewBootstrap(ctx, th.ID, domainID, token) ps.errLog(errors.Wrap(ErrFailedBootstrapRetrieval, err)) - ps.errLog(ps.sdk.RemoveBootstrap(bs.ClientID, domainID, token)) + ps.errLog(ps.sdk.RemoveBootstrap(ctx, bs.ClientID, domainID, token)) } } return diff --git a/provision/service_test.go b/provision/service_test.go index 37a119bfe..a329abeb5 100644 --- a/provision/service_test.go +++ b/provision/service_test.go @@ -4,6 +4,7 @@ package provision_test import ( + "context" "fmt" "testing" @@ -51,8 +52,8 @@ func TestMapping(t *testing.T) { for _, c := range cases { t.Run(c.desc, func(t *testing.T) { pm := smqSDK.PageMetadata{Offset: uint64(0), Limit: uint64(10)} - repocall := mgsdk.On("Users", pm, c.token).Return(smqSDK.UsersPage{}, c.sdkerr) - content, err := svc.Mapping(c.token) + repocall := mgsdk.On("Users", mock.Anything, pm, c.token).Return(smqSDK.UsersPage{}, c.sdkerr) + content, err := svc.Mapping(context.Background(), c.token) assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v, got %v", c.err, err)) assert.Equal(t, c.content, content) repocall.Unset() @@ -215,15 +216,15 @@ func TestCert(t *testing.T) { mgsdk := new(sdkmocks.SDK) svc := provision.New(c.config, mgsdk, smqlog.NewMock()) - mgsdk.On("Client", c.clientID, c.domainID, mock.Anything).Return(smqSDK.Client{ID: c.clientID}, c.sdkClientErr) - mgsdk.On("IssueCert", c.clientID, c.config.Cert.TTL, c.domainID, mock.Anything).Return(smqSDK.Cert{SerialNumber: c.serial}, c.sdkCertErr) - mgsdk.On("ViewCert", c.serial, mock.Anything, mock.Anything).Return(smqSDK.Cert{Certificate: c.cert, Key: c.key}, c.sdkCertErr) + mgsdk.On("Client", mock.Anything, c.clientID, c.domainID, mock.Anything).Return(smqSDK.Client{ID: c.clientID}, c.sdkClientErr) + mgsdk.On("IssueCert", mock.Anything, c.clientID, c.config.Cert.TTL, c.domainID, mock.Anything).Return(smqSDK.Cert{SerialNumber: c.serial}, c.sdkCertErr) + mgsdk.On("ViewCert", mock.Anything, c.serial, mock.Anything, mock.Anything).Return(smqSDK.Cert{Certificate: c.cert, Key: c.key}, c.sdkCertErr) login := smqSDK.Login{ Username: c.config.Server.MgUsername, Password: c.config.Server.MgPass, } - mgsdk.On("CreateToken", login).Return(smqSDK.Token{AccessToken: validToken}, c.sdkTokenErr) - cert, key, err := svc.Cert(c.domainID, c.token, c.clientID, c.ttl) + mgsdk.On("CreateToken", mock.Anything, login).Return(smqSDK.Token{AccessToken: validToken}, c.sdkTokenErr) + cert, key, err := svc.Cert(context.Background(), c.domainID, c.token, c.clientID, c.ttl) assert.Equal(t, c.cert, cert) assert.Equal(t, c.key, key) assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v, got %v", c.err, err)) diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index 8bcea0690..225da972f 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -33,3 +33,7 @@ packages: github.com/absmach/magistrala/provision: interfaces: Service: + github.com/absmach/magistrala/alarms: + interfaces: + Service: + Repository: \ No newline at end of file diff --git a/tools/e2e/cmd/main.go b/tools/e2e/cmd/main.go index a1d10e03b..3a6d6715a 100644 --- a/tools/e2e/cmd/main.go +++ b/tools/e2e/cmd/main.go @@ -30,8 +30,8 @@ func main() { "go run tools/e2e/cmd/main.go\n" + "go run tools/e2e/cmd/main.go --host 142.93.118.47\n" + "go run tools/e2e/cmd/main.go --host localhost --num 10 --num_of_messages 100 --prefix e2e", - Run: func(_ *cobra.Command, _ []string) { - e2e.Test(econf) + Run: func(cmd *cobra.Command, _ []string) { + e2e.Test(cmd.Context(), econf) }, } diff --git a/tools/e2e/e2e.go b/tools/e2e/e2e.go index d3622f671..4d2b5e13f 100644 --- a/tools/e2e/e2e.go +++ b/tools/e2e/e2e.go @@ -66,7 +66,7 @@ type Config struct { // - Connect client to channel // - Publish message from HTTP, MQTT, WS and CoAP Adapters. -func Test(conf Config) { +func Test(ctx context.Context, conf Config) { sdkConf := sdk.Config{ UsersURL: fmt.Sprintf("http://%s:%s", conf.Host, usersPort), GroupsURL: fmt.Sprintf("http://%s:%s", conf.Host, groupsPort), @@ -82,51 +82,51 @@ func Test(conf Config) { magenta := color.FgLightMagenta.Render - domainID, token, err := createUser(s, conf) + domainID, token, err := createUser(ctx, s, conf) if err != nil { errExit(fmt.Errorf("unable to create user: %w", err)) } color.Success.Printf("created user with token %s\n", magenta(token)) color.Success.Printf("created domain with ID %s\n", magenta(domainID)) - users, err := createUsers(s, conf, token) + users, err := createUsers(ctx, s, conf, token) if err != nil { errExit(fmt.Errorf("unable to create users: %w", err)) } color.Success.Printf("created users of ids:\n%s\n", magenta(getIDS(users))) - groups, err := createGroups(s, conf, domainID, token) + groups, err := createGroups(ctx, s, conf, domainID, token) if err != nil { errExit(fmt.Errorf("unable to create groups: %w", err)) } color.Success.Printf("created groups of ids:\n%s\n", magenta(getIDS(groups))) - clients, err := createClients(s, conf, domainID, token) + clients, err := createClients(ctx, s, conf, domainID, token) if err != nil { errExit(fmt.Errorf("unable to create clients: %w", err)) } color.Success.Printf("created clients of ids:\n%s\n", magenta(getIDS(clients))) - channels, err := createChannels(s, conf, domainID, token) + channels, err := createChannels(ctx, s, conf, domainID, token) if err != nil { errExit(fmt.Errorf("unable to create channels: %w", err)) } color.Success.Printf("created channels of ids:\n%s\n", magenta(getIDS(channels))) // List users, groups, clients and channels - if err := read(s, conf, domainID, token, users, groups, clients, channels); err != nil { + if err := read(ctx, s, conf, domainID, token, users, groups, clients, channels); err != nil { errExit(fmt.Errorf("unable to read users, groups, clients and channels: %w", err)) } color.Success.Println("viewed users, groups, clients and channels") // Update users, groups, clients and channels - if err := update(s, domainID, token, users, groups, clients, channels); err != nil { + if err := update(ctx, s, domainID, token, users, groups, clients, channels); err != nil { errExit(fmt.Errorf("unable to update users, groups, clients and channels: %w", err)) } color.Success.Println("updated users, groups, clients and channels") // Send messages to channels - if err := messaging(s, conf, domainID, token, clients, channels); err != nil { + if err := messaging(ctx, s, conf, domainID, token, clients, channels); err != nil { errExit(fmt.Errorf("unable to send messages to channels: %w", err)) } color.Success.Println("sent messages to channels") @@ -137,7 +137,7 @@ func errExit(err error) { os.Exit(1) } -func createUser(s sdk.SDK, conf Config) (string, string, error) { +func createUser(ctx context.Context, s sdk.SDK, conf Config) (string, string, error) { user := sdk.User{ FirstName: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), LastName: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), @@ -149,8 +149,8 @@ func createUser(s sdk.SDK, conf Config) (string, string, error) { Status: sdk.EnabledStatus, Role: "admin", } - ctx := context.Background() - if _, err := s.CreateUser(context.Background(), user, ""); err != nil { + + if _, err := s.CreateUser(ctx, user, ""); err != nil { return "", "", fmt.Errorf("unable to create user: %w", err) } @@ -187,10 +187,9 @@ func createUser(s sdk.SDK, conf Config) (string, string, error) { return domain.ID, token.AccessToken, nil } -func createUsers(s sdk.SDK, conf Config, token string) ([]sdk.User, error) { +func createUsers(ctx context.Context, s sdk.SDK, conf Config, token string) ([]sdk.User, error) { var err error users := []sdk.User{} - ctx := context.Background() for i := uint64(0); i < conf.Num; i++ { user := sdk.User{ @@ -214,10 +213,9 @@ func createUsers(s sdk.SDK, conf Config, token string) ([]sdk.User, error) { return users, nil } -func createGroups(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Group, error) { +func createGroups(ctx context.Context, s sdk.SDK, conf Config, domainID, token string) ([]sdk.Group, error) { var err error groups := []sdk.Group{} - ctx := context.Background() for i := uint64(0); i < conf.Num; i++ { group := sdk.Group{ @@ -235,10 +233,9 @@ func createGroups(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Group, return groups, nil } -func createClientsInBatch(s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Client, error) { +func createClientsInBatch(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Client, error) { var err error clients := make([]sdk.Client, num) - ctx := context.Background() for i := uint64(0); i < num; i++ { clients[i] = sdk.Client{ @@ -254,25 +251,25 @@ func createClientsInBatch(s sdk.SDK, conf Config, domainID, token string, num ui return clients, nil } -func createClients(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Client, error) { +func createClients(ctx context.Context, s sdk.SDK, conf Config, domainID, token string) ([]sdk.Client, error) { clients := []sdk.Client{} if conf.Num > batchSize { batches := int(conf.Num) / batchSize for i := 0; i < batches; i++ { - ths, err := createClientsInBatch(s, conf, domainID, token, batchSize) + ths, err := createClientsInBatch(ctx, s, conf, domainID, token, batchSize) if err != nil { return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) } clients = append(clients, ths...) } - ths, err := createClientsInBatch(s, conf, domainID, token, conf.Num%uint64(batchSize)) + ths, err := createClientsInBatch(ctx, s, conf, domainID, token, conf.Num%uint64(batchSize)) if err != nil { return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) } clients = append(clients, ths...) } else { - ths, err := createClientsInBatch(s, conf, domainID, token, conf.Num) + ths, err := createClientsInBatch(ctx, s, conf, domainID, token, conf.Num) if err != nil { return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) } @@ -282,10 +279,9 @@ func createClients(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Client return clients, nil } -func createChannelsInBatch(s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Channel, error) { +func createChannelsInBatch(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Channel, error) { var err error channels := make([]sdk.Channel, num) - ctx := context.Background() for i := uint64(0); i < num; i++ { channels[i] = sdk.Channel{ @@ -300,25 +296,25 @@ func createChannelsInBatch(s sdk.SDK, conf Config, domainID, token string, num u return channels, nil } -func createChannels(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Channel, error) { +func createChannels(ctx context.Context, s sdk.SDK, conf Config, domainID, token string) ([]sdk.Channel, error) { channels := []sdk.Channel{} if conf.Num > batchSize { batches := int(conf.Num) / batchSize for i := 0; i < batches; i++ { - chs, err := createChannelsInBatch(s, conf, token, domainID, batchSize) + chs, err := createChannelsInBatch(ctx, s, conf, token, domainID, batchSize) if err != nil { return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) } channels = append(channels, chs...) } - chs, err := createChannelsInBatch(s, conf, domainID, token, conf.Num%uint64(batchSize)) + chs, err := createChannelsInBatch(ctx, s, conf, domainID, token, conf.Num%uint64(batchSize)) if err != nil { return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) } channels = append(channels, chs...) } else { - chs, err := createChannelsInBatch(s, conf, domainID, token, conf.Num) + chs, err := createChannelsInBatch(ctx, s, conf, domainID, token, conf.Num) if err != nil { return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) } @@ -328,8 +324,7 @@ func createChannels(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Chann return channels, nil } -func read(s sdk.SDK, conf Config, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { - ctx := context.Background() +func read(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { for _, user := range users { if _, err := s.User(ctx, user.ID, token); err != nil { return fmt.Errorf("failed to get user %w", err) @@ -382,8 +377,7 @@ func read(s sdk.SDK, conf Config, domainID, token string, users []sdk.User, grou return nil } -func update(s sdk.SDK, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { - ctx := context.Background() +func update(ctx context.Context, s sdk.SDK, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { for _, user := range users { user.FirstName = namesgenerator.Generate() user.Metadata = sdk.Metadata{"Update": namesgenerator.Generate()} @@ -545,8 +539,7 @@ func update(s sdk.SDK, domainID, token string, users []sdk.User, groups []sdk.Gr return nil } -func messaging(s sdk.SDK, conf Config, domainID, token string, clients []sdk.Client, channels []sdk.Channel) error { - ctx := context.Background() +func messaging(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, clients []sdk.Client, channels []sdk.Channel) error { for _, c := range clients { for _, channel := range channels { conn := sdk.Connection{ @@ -569,7 +562,7 @@ func messaging(s sdk.SDK, conf Config, domainID, token string, clients []sdk.Cli func(num int64, client sdk.Client, channel sdk.Channel) { g.Go(func() error { msg := fmt.Sprintf(msgFormat, num+1, rand.Int()) - return sendHTTPMessage(s, msg, client, channel.ID) + return sendHTTPMessage(ctx, s, msg, client, channel.ID) }) g.Go(func() error { msg := fmt.Sprintf(msgFormat, num+2, rand.Int()) @@ -592,8 +585,7 @@ func messaging(s sdk.SDK, conf Config, domainID, token string, clients []sdk.Cli return g.Wait() } -func sendHTTPMessage(s sdk.SDK, msg string, client sdk.Client, chanID string) error { - ctx := context.Background() +func sendHTTPMessage(ctx context.Context, s sdk.SDK, msg string, client sdk.Client, chanID string) error { if err := s.SendMessage(ctx, client.DomainID, chanID, msg, client.Credentials.Secret); err != nil { return fmt.Errorf("HTTP failed to send message from client %s to channel %s: %w", client.ID, chanID, err) } diff --git a/tools/provision/cmd/main.go b/tools/provision/cmd/main.go index 5326ee930..632d8555a 100644 --- a/tools/provision/cmd/main.go +++ b/tools/provision/cmd/main.go @@ -19,8 +19,8 @@ func main() { Short: "provision is provisioning tool for SuperMQ", Long: `Tool for provisioning series of SuperMQ channels and clients and connecting them together. Complete documentation is available at https://docs.supermq.abstractmachines.fr`, - Run: func(_ *cobra.Command, _ []string) { - if err := provision.Provision(pconf); err != nil { + Run: func(cmd *cobra.Command, _ []string) { + if err := provision.Provision(cmd.Context(), pconf); err != nil { log.Fatal(err) } }, diff --git a/tools/provision/provision.go b/tools/provision/provision.go index 91bfe9bd4..80d8dd765 100644 --- a/tools/provision/provision.go +++ b/tools/provision/provision.go @@ -56,8 +56,7 @@ type Config struct { } // Provision - function that does actual provisiong. -func Provision(conf Config) error { - ctx := context.Background() +func Provision(ctx context.Context, conf Config) error { const ( rsaBits = 4096 ttl = "2400h"