From b3e2f41194287651bd1b91ea3368e3cfaad752fd Mon Sep 17 00:00:00 2001 From: b1ackd0t <28790446+rodneyosodo@users.noreply.github.com> Date: Tue, 15 Apr 2025 20:32:09 +0300 Subject: [PATCH] NOISSUE - Add Alarms (#106) * WIP: alarms service * fix(alarms): remove rule entity since it is not stored here Signed-off-by: Rodney Osodo * test(alarms): add tests cases for invalid alarms * feat(alarms): add authorization * feat(alarms): add docker deployment files Signed-off-by: Rodney Osodo * fix: update go mod file * feat(alarms): support filtering by resolved_by, updated_by and severity Signed-off-by: Rodney Osodo * style: fix linter errors Signed-off-by: Rodney Osodo * fix(alarms): provide correct otel naming for create alarm Fixes https://github.com/absmach/magistrala/pull/106#discussion_r2030151971 Signed-off-by: Rodney Osodo * fix(alarms): group routes appropriately Resolves https://github.com/absmach/magistrala/pull/106#discussion_r2030160891 Signed-off-by: Rodney Osodo * fix(alarms): extract alarm id from url path rather than query params Signed-off-by: Rodney Osodo * fix(alarms): add all status to help in decoding Signed-off-by: Rodney Osodo * style(alarms): maintain consistent import as naming for supermq api package Signed-off-by: Rodney Osodo * refactor(alarms): update supermq dependecy to the latest Signed-off-by: Rodney Osodo * fix(alarms): Add domains gRPC service config to alarms service Signed-off-by: Rodney Osodo * test(alarms): all CRUD operations from the service Return empty results instead of nil This standardizes error responses across alarm endpoints to return empty result structs rather than nil. Also renames entityReq to alarmReq and adds HTTP status codes for created/deleted alarms. Signed-off-by: Rodney Osodo * test(alarms): fix failing tests due to introduction of context on sdk Signed-off-by: Rodney Osodo * fix(alarms): remove channel id Signed-off-by: Rodney Osodo * fix(alarms): standardize error handling across CRUD operations Updated error responses to use specific repository errors for consistency Signed-off-by: Rodney Osodo * feat(alarms): add assignment fields to Alarm model and database Introduced AssignedAt and AssignedBy fields to the Alarm struct and updated the database schema accordingly. Enhanced the UpdateAlarm function to handle these new fields, ensuring proper assignment tracking in the alarms system. Signed-off-by: Rodney Osodo * feat(alarms): enhance Alarm model with measurement attributes Updated the Alarm struct to include Measurement, Value, Unit, and Cause fields. Modified the validation logic to ensure these fields are present. Adjusted logging and tracing middleware to reflect the new attributes. Updated database schema and related functions to accommodate these changes, ensuring comprehensive alarm data management. Signed-off-by: Rodney Osodo * feat(alarms): consume events from pubsub for creation of alarms Removed session dependencies from CreateAlarm method and enhanced alarm validation to ensure all required fields are present Signed-off-by: Rodney Osodo * style(alarms): add newline at the end of docker compose Signed-off-by: Rodney Osodo * fix(alarms): Add assignee id and metadata fields when consuming messages Signed-off-by: Rodney Osodo * feat(alarms): add acknowledged field Signed-off-by: Rodney Osodo * feat(alarms): Add threshold value for the specific measurement Signed-off-by: Rodney Osodo * feat(alarms): Add channel, thing, and subtopic fields to Alarm model This change adds required fields for tracking alarm sources and reorganizes alarm-related fields for better grouping. Alarms now track the channel, thing, and subtopic that triggered them, along with domain and rule info. Signed-off-by: Rodney Osodo * test(alarms): add service layer tests Signed-off-by: Rodney Osodo * fix(alarms): consume created at from message rather than creating it Signed-off-by: Rodney Osodo * feat(alarms): ready alarm as a gob encoded object Signed-off-by: Rodney Osodo * fix(alarms): read alarms from alarms queue and remove transformer g Signed-off-by: Rodney Osodo * feat(alarms): update version of supermq Signed-off-by: Rodney Osodo * feat(alarms): add gob transformer Signed-off-by: Rodney Osodo * fix(alarms): rename thing id to client id Signed-off-by: Rodney Osodo * fix(alarms): create alarms stream Signed-off-by: Rodney Osodo * fix(alarms): check on logic to create new alarm create new alarm if severity, status, subtopic changes enhance logging with additional details for alarms management Signed-off-by: Rodney Osodo * remove conusmer and use pubsub Signed-off-by: Rodney Osodo * fix(alarms): use build tags for rabbitmq and nats * fix(alarms): add health and metrics endpoint * fix(magistrala): use supermq as build flags to see version and commit * fix(alarms): use js config * fix(alarms): remove validation when updating an alarm fix authorization too --------- Signed-off-by: Rodney Osodo --- Makefile | 8 +- alarms/alarms.go | 122 +++++ alarms/alarms_test.go | 203 +++++++++ alarms/api/doc.go | 6 + alarms/api/endpoint.go | 105 +++++ alarms/api/requests.go | 36 ++ alarms/api/responses.go | 70 +++ alarms/api/transport.go | 176 ++++++++ alarms/consumer/brokers/brokers_nats.go | 39 ++ alarms/consumer/brokers/brokers_rabbitmq.go | 26 ++ alarms/consumer/consumer.go | 57 +++ alarms/doc.go | 6 + alarms/middleware/authorization.go | 124 +++++ alarms/middleware/doc.go | 6 + alarms/middleware/logging.go | 151 +++++++ alarms/middleware/metrics.go | 74 +++ alarms/middleware/tracing.go | 84 ++++ alarms/mocks/repository.go | 309 +++++++++++++ alarms/mocks/service.go | 304 +++++++++++++ alarms/postgres/alarms.go | 427 ++++++++++++++++++ alarms/postgres/alarms_test.go | 474 ++++++++++++++++++++ alarms/postgres/init.go | 52 +++ alarms/postgres/setup_test.go | 93 ++++ alarms/service.go | 84 ++++ alarms/service_test.go | 253 +++++++++++ alarms/status.go | 70 +++ bootstrap/events/producer/streams.go | 47 +- bootstrap/events/producer/streams_test.go | 6 +- bootstrap/service.go | 1 + bootstrap/service_test.go | 10 +- cli/bootstrap.go | 20 +- cli/bootstrap_test.go | 20 +- cli/consumers.go | 8 +- cli/consumers_test.go | 8 +- cli/provision.go | 31 +- cmd/alarms/main.go | 192 ++++++++ docker/.env | 17 + docker/docker-compose.yaml | 89 ++++ go.mod | 25 +- go.sum | 4 +- pkg/sdk/bootstrap.go | 41 +- pkg/sdk/bootstrap_test.go | 21 +- pkg/sdk/consumers.go | 17 +- pkg/sdk/consumers_test.go | 9 +- pkg/sdk/messages.go | 5 +- pkg/sdk/messages_test.go | 3 +- pkg/sdk/mocks/sdk.go | 381 ++++++++-------- pkg/sdk/sdk.go | 65 +-- provision/api/endpoint.go | 8 +- provision/api/endpoint_test.go | 5 +- provision/api/logging.go | 13 +- provision/mocks/service.go | 87 ++-- provision/service.go | 83 ++-- provision/service_test.go | 15 +- tools/config/.mockery.yaml | 4 + tools/e2e/cmd/main.go | 4 +- tools/e2e/e2e.go | 66 ++- tools/provision/cmd/main.go | 4 +- tools/provision/provision.go | 3 +- 59 files changed, 4170 insertions(+), 501 deletions(-) create mode 100644 alarms/alarms.go create mode 100644 alarms/alarms_test.go create mode 100644 alarms/api/doc.go create mode 100644 alarms/api/endpoint.go create mode 100644 alarms/api/requests.go create mode 100644 alarms/api/responses.go create mode 100644 alarms/api/transport.go create mode 100644 alarms/consumer/brokers/brokers_nats.go create mode 100644 alarms/consumer/brokers/brokers_rabbitmq.go create mode 100644 alarms/consumer/consumer.go create mode 100644 alarms/doc.go create mode 100644 alarms/middleware/authorization.go create mode 100644 alarms/middleware/doc.go create mode 100644 alarms/middleware/logging.go create mode 100644 alarms/middleware/metrics.go create mode 100644 alarms/middleware/tracing.go create mode 100644 alarms/mocks/repository.go create mode 100644 alarms/mocks/service.go create mode 100644 alarms/postgres/alarms.go create mode 100644 alarms/postgres/alarms_test.go create mode 100644 alarms/postgres/init.go create mode 100644 alarms/postgres/setup_test.go create mode 100644 alarms/service.go create mode 100644 alarms/service_test.go create mode 100644 alarms/status.go create mode 100644 cmd/alarms/main.go diff --git a/Makefile b/Makefile index e016c0080..44592f74f 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ MG_DOCKER_IMAGE_NAME_PREFIX ?= ghcr.io/absmach/magistrala BUILD_DIR = build -SERVICES = bootstrap provision re postgres-writer postgres-reader timescale-writer timescale-reader cli +SERVICES = bootstrap provision re postgres-writer postgres-reader timescale-writer timescale-reader cli alarms DOCKERS = $(addprefix docker_,$(SERVICES)) DOCKERS_DEV = $(addprefix docker_dev_,$(SERVICES)) CGO_ENABLED ?= 0 @@ -39,9 +39,9 @@ endif define compile_service CGO_ENABLED=$(CGO_ENABLED) GOOS=$(GOOS) GOARCH=$(GOARCH) GOARM=$(GOARM) \ go build -tags $(MG_MESSAGE_BROKER_TYPE) -tags $(MG_ES_TYPE) -ldflags "-s -w \ - -X 'github.com/absmach/magistrala.BuildTime=$(TIME)' \ - -X 'github.com/absmach/magistrala.Version=$(VERSION)' \ - -X 'github.com/absmach/magistrala.Commit=$(COMMIT)'" \ + -X 'github.com/absmach/supermq.BuildTime=$(TIME)' \ + -X 'github.com/absmach/supermq.Version=$(VERSION)' \ + -X 'github.com/absmach/supermq.Commit=$(COMMIT)'" \ -o ${BUILD_DIR}/$(1) cmd/$(1)/main.go endef diff --git a/alarms/alarms.go b/alarms/alarms.go new file mode 100644 index 000000000..ed4cc6647 --- /dev/null +++ b/alarms/alarms.go @@ -0,0 +1,122 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms + +import ( + "context" + "errors" + "time" + + "github.com/absmach/supermq/pkg/authn" +) + +const SeverityMax uint8 = 100 + +var ErrInvalidSeverity = errors.New("invalid severity. Must be between 0 and 100") + +type Metadata map[string]interface{} + +// Alarm represents an alarm instance. +type Alarm struct { + ID string `json:"id"` + RuleID string `json:"rule_id"` + DomainID string `json:"domain_id"` + ChannelID string `json:"channel_id"` + ClientID string `json:"client_id"` + Subtopic string `json:"subtopic"` + Status Status `json:"status"` + Measurement string `json:"measurement"` + Value string `json:"value"` + Unit string `json:"unit"` + Threshold string `json:"threshold"` + Cause string `json:"cause"` + Severity uint8 `json:"severity"` + AssigneeID string `json:"assignee_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + UpdatedBy string `json:"updated_by"` + AssignedAt time.Time `json:"assigned_at,omitempty"` + AssignedBy string `json:"assigned_by,omitempty"` + AcknowledgedAt time.Time `json:"acknowledged_at,omitempty"` + AcknowledgedBy string `json:"acknowledged_by,omitempty"` + ResolvedAt time.Time `json:"resolved_at,omitempty"` + ResolvedBy string `json:"resolved_by,omitempty"` + Metadata Metadata `json:"metadata,omitempty"` +} + +type AlarmsPage struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Total uint64 `json:"total"` + Alarms []Alarm `json:"alarms"` +} + +type PageMetadata struct { + Offset uint64 `json:"offset" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + DomainID string `json:"domain_id" db:"domain_id"` + ChannelID string `json:"channel_id" db:"channel_id"` + ClientID string `json:"client_id" db:"client_id"` + Subtopic string `json:"subtopic" db:"subtopic"` + RuleID string `json:"rule_id" db:"rule_id"` + Status Status `json:"status" db:"status"` + AssigneeID string `json:"assignee_id" db:"assignee_id"` + Severity uint8 `json:"severity" db:"severity"` + UpdatedBy string `json:"updated_by" db:"updated_by"` + AssignedBy string `json:"assigned_by" db:"assigned_by"` + AcknowledgedBy string `json:"acknowledged_by" db:"acknowledged_by"` + ResolvedBy string `json:"resolved_by" db:"resolved_by"` +} + +func (a Alarm) Validate() error { + if a.RuleID == "" { + return errors.New("rule_id is required") + } + if a.DomainID == "" { + return errors.New("domain_id is required") + } + if a.ChannelID == "" { + return errors.New("channel_id is required") + } + if a.ClientID == "" { + return errors.New("client_id is required") + } + if a.Subtopic == "" { + return errors.New("subtopic is required") + } + if a.Measurement == "" { + return errors.New("measurement is required") + } + if a.Value == "" { + return errors.New("value is required") + } + if a.Unit == "" { + return errors.New("unit is required") + } + if a.Cause == "" { + return errors.New("cause is required") + } + if a.Severity > SeverityMax { + return ErrInvalidSeverity + } + + return nil +} + +// Service specifies an API that must be fulfilled by the domain service. +type Service interface { + CreateAlarm(ctx context.Context, alarm Alarm) error + UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error) + ViewAlarm(ctx context.Context, session authn.Session, id string) (Alarm, error) + ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error) + DeleteAlarm(ctx context.Context, session authn.Session, id string) error +} + +type Repository interface { + CreateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) + UpdateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) + ViewAlarm(ctx context.Context, alarmID, domainID string) (Alarm, error) + ListAlarms(ctx context.Context, pm PageMetadata) (AlarmsPage, error) + DeleteAlarm(ctx context.Context, id string) error +} diff --git a/alarms/alarms_test.go b/alarms/alarms_test.go new file mode 100644 index 000000000..2b1d5ac01 --- /dev/null +++ b/alarms/alarms_test.go @@ -0,0 +1,203 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms_test + +import ( + "fmt" + "testing" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/magistrala/internal/testsutil" + "github.com/absmach/magistrala/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateAlarms(t *testing.T) { + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: nil, + }, + { + desc: "missing rule_id", + alarm: alarms.Alarm{ + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("rule_id is required"), + }, + { + desc: "missing domain_id", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("domain_id is required"), + }, + { + desc: "missing channel_id", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("channel_id is required"), + }, + { + desc: "missing client_id", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("client_id is required"), + }, + { + desc: "missing subtopic", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("subtopic is required"), + }, + { + desc: "missing measurement", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("measurement is required"), + }, + { + desc: "missing value", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("value is required"), + }, + { + desc: "missing unit", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Cause: "cause", + Severity: 100, + }, + err: errors.New("unit is required"), + }, + { + desc: "missing cause", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Severity: 100, + }, + err: errors.New("cause is required"), + }, + { + desc: "higher severity", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: alarms.SeverityMax + 1, + }, + err: alarms.ErrInvalidSeverity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.alarm.Validate() + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + }) + } +} diff --git a/alarms/api/doc.go b/alarms/api/doc.go new file mode 100644 index 000000000..2424852cc --- /dev/null +++ b/alarms/api/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package api contains API-related concerns: endpoint definitions, middlewares +// and all resource representations. +package api diff --git a/alarms/api/endpoint.go b/alarms/api/endpoint.go new file mode 100644 index 000000000..f35e7c6d2 --- /dev/null +++ b/alarms/api/endpoint.go @@ -0,0 +1,105 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/go-kit/kit/endpoint" +) + +func updateAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(alarmReq) + if err := req.validate(); err != nil { + return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return alarmRes{}, svcerr.ErrAuthorization + } + + alarm, err := svc.UpdateAlarm(ctx, session, req.Alarm) + if err != nil { + return alarmRes{}, err + } + + return alarmRes{ + Alarm: alarm, + }, nil + } +} + +func viewAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(alarmReq) + if err := req.validate(); err != nil { + return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return alarmRes{}, svcerr.ErrAuthorization + } + + alarm, err := svc.ViewAlarm(ctx, session, req.ID) + if err != nil { + return alarmRes{}, err + } + + return alarmRes{ + Alarm: alarm, + }, nil + } +} + +func listAlarmsEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(listAlarmsReq) + if err := req.validate(); err != nil { + return alarmsPageRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return alarmsPageRes{}, svcerr.ErrAuthorization + } + + alarms, err := svc.ListAlarms(ctx, session, req.PageMetadata) + if err != nil { + return alarmsPageRes{}, err + } + + return alarmsPageRes{ + AlarmsPage: alarms, + }, nil + } +} + +func deleteAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(alarmReq) + if err := req.validate(); err != nil { + return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return alarmRes{}, svcerr.ErrAuthorization + } + + if err := svc.DeleteAlarm(ctx, session, req.ID); err != nil { + return alarmRes{}, err + } + + return alarmRes{deleted: true}, nil + } +} diff --git a/alarms/api/requests.go b/alarms/api/requests.go new file mode 100644 index 000000000..54f677efb --- /dev/null +++ b/alarms/api/requests.go @@ -0,0 +1,36 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "errors" + + "github.com/absmach/magistrala/alarms" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" +) + +type alarmReq struct { + alarms.Alarm `json:",inline"` +} + +func (req alarmReq) validate() error { + if req.Alarm.ID == "" { + return errors.New("missing alarm id") + } + + return nil +} + +type listAlarmsReq struct { + alarms.PageMetadata +} + +func (req listAlarmsReq) validate() error { + if req.Limit > api.MaxLimitSize || req.Limit < 1 { + return apiutil.ErrLimitSize + } + + return nil +} diff --git a/alarms/api/responses.go b/alarms/api/responses.go new file mode 100644 index 000000000..4a499735d --- /dev/null +++ b/alarms/api/responses.go @@ -0,0 +1,70 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "net/http" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq" +) + +var ( + _ supermq.Response = (*alarmRes)(nil) + _ supermq.Response = (*alarmsPageRes)(nil) +) + +type alarmRes struct { + alarms.Alarm `json:",inline"` + created bool + deleted bool +} + +func (res alarmRes) Headers() map[string]string { + switch { + case res.created: + return map[string]string{ + "Location": fmt.Sprintf("/%s/alarms/%s", res.DomainID, res.ID), + } + default: + return map[string]string{} + } +} + +func (res alarmRes) Code() int { + switch { + case res.created: + return http.StatusCreated + case res.deleted: + return http.StatusNoContent + default: + return http.StatusOK + } +} + +func (res alarmRes) Empty() bool { + switch { + case res.deleted: + return true + default: + return false + } +} + +type alarmsPageRes struct { + alarms.AlarmsPage `json:",inline"` +} + +func (res alarmsPageRes) Headers() map[string]string { + return map[string]string{} +} + +func (res alarmsPageRes) Code() int { + return http.StatusOK +} + +func (res alarmsPageRes) Empty() bool { + return false +} diff --git a/alarms/api/transport.go b/alarms/api/transport.go new file mode 100644 index 000000000..54e2d9118 --- /dev/null +++ b/alarms/api/transport.go @@ -0,0 +1,176 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "log/slog" + "math" + "net/http" + "strings" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" +) + +func MakeHandler(svc alarms.Service, logger *slog.Logger, idp supermq.IDProvider, instanceID string, authn smqauthn.Authentication) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + + mux := chi.NewRouter() + + mux.Route("/{domainID}/alarms", func(r chi.Router) { + r.Group(func(r chi.Router) { + r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(api.RequestIDMiddleware(idp)) + + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listAlarmsEndpoint(svc), + decodeListAlarmsReq, + api.EncodeResponse, + opts..., + ), "list_alarms").ServeHTTP) + r.Route("/{alarmID}", func(r chi.Router) { + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + viewAlarmEndpoint(svc), + decodeAlarmReq, + api.EncodeResponse, + opts..., + ), "get_alarm").ServeHTTP) + r.Put("/", otelhttp.NewHandler(kithttp.NewServer( + updateAlarmEndpoint(svc), + decodeUpdateAlarmReq, + api.EncodeResponse, + opts..., + ), "update_alarm").ServeHTTP) + r.Delete("/", otelhttp.NewHandler(kithttp.NewServer( + deleteAlarmEndpoint(svc), + decodeAlarmReq, + api.EncodeResponse, + opts..., + ), "delete_alarm").ServeHTTP) + }) + }) + }) + + mux.Get("/health", supermq.Health("alarms", instanceID)) + mux.Handle("/metrics", promhttp.Handler()) + + return mux +} + +func decodeListAlarmsReq(_ context.Context, r *http.Request) (interface{}, error) { + offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + domainID, err := apiutil.ReadStringQuery(r, "domain_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + channelID, err := apiutil.ReadStringQuery(r, "channel_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + clientID, err := apiutil.ReadStringQuery(r, "client_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + subtopic, err := apiutil.ReadStringQuery(r, "subtopic", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + ruleID, err := apiutil.ReadStringQuery(r, "rule_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + s, err := apiutil.ReadStringQuery(r, api.StatusKey, alarms.All) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + status, err := alarms.ToStatus(s) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + assigneeID, err := apiutil.ReadStringQuery(r, "assignee_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + serverity, err := apiutil.ReadNumQuery(r, "severity", uint64(math.MaxUint8)) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + updatedBy, err := apiutil.ReadStringQuery(r, "updated_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + assignedBy, err := apiutil.ReadStringQuery(r, "assigned_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + acknowledgedBy, err := apiutil.ReadStringQuery(r, "acknowledged_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + resolvedBy, err := apiutil.ReadStringQuery(r, "resolved_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + return listAlarmsReq{ + PageMetadata: alarms.PageMetadata{ + Offset: offset, + Limit: limit, + DomainID: domainID, + ChannelID: channelID, + ClientID: clientID, + Subtopic: subtopic, + RuleID: ruleID, + Status: status, + AssigneeID: assigneeID, + ResolvedBy: resolvedBy, + Severity: uint8(serverity), + UpdatedBy: updatedBy, + AcknowledgedBy: acknowledgedBy, + AssignedBy: assignedBy, + }, + }, nil +} + +func decodeAlarmReq(_ context.Context, r *http.Request) (interface{}, error) { + return alarmReq{ + Alarm: alarms.Alarm{ + ID: chi.URLParam(r, "alarmID"), + }, + }, nil +} + +func decodeUpdateAlarmReq(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return alarmReq{}, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) + } + + req := alarmReq{} + if err := json.NewDecoder(r.Body).Decode(&req.Alarm); err != nil { + return alarmReq{}, errors.Wrap(apiutil.ErrValidation, errors.Wrap(errors.ErrMalformedEntity, err)) + } + + req.Alarm.ID = chi.URLParam(r, "alarmID") + + return req, nil +} diff --git a/alarms/consumer/brokers/brokers_nats.go b/alarms/consumer/brokers/brokers_nats.go new file mode 100644 index 000000000..0f24a26dc --- /dev/null +++ b/alarms/consumer/brokers/brokers_nats.go @@ -0,0 +1,39 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !rabbitmq +// +build !rabbitmq + +package brokers + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/pkg/messaging" + broker "github.com/absmach/supermq/pkg/messaging/nats" + "github.com/nats-io/nats.go/jetstream" +) + +const AllTopic = "alarms.>" + +func NewPubSub(ctx context.Context, url string, logger *slog.Logger) (messaging.PubSub, error) { + cfg := jetstream.StreamConfig{ + Name: "alarms", + Description: "SuperMQ stream alarms", + Subjects: []string{"alarms.>"}, + Retention: jetstream.LimitsPolicy, + MaxMsgsPerSubject: 1e6, + MaxAge: time.Hour * 24, + MaxMsgSize: 1024 * 1024, + Discard: jetstream.DiscardOld, + Storage: jetstream.FileStorage, + } + pb, err := broker.NewPubSub(ctx, url, logger, broker.JSStreamConfig(cfg)) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/alarms/consumer/brokers/brokers_rabbitmq.go b/alarms/consumer/brokers/brokers_rabbitmq.go new file mode 100644 index 000000000..cb0960223 --- /dev/null +++ b/alarms/consumer/brokers/brokers_rabbitmq.go @@ -0,0 +1,26 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build rabbitmq +// +build rabbitmq + +package brokers + +import ( + "context" + "log/slog" + + "github.com/absmach/supermq/pkg/messaging" + broker "github.com/absmach/supermq/pkg/messaging/rabbitmq" +) + +const AllTopic = "alarms.#" + +func NewPubSub(ctx context.Context, url string, logger *slog.Logger) (messaging.PubSub, error) { + pb, err := broker.NewPubSub(ctx, url, logger, broker.Prefix("alarms")) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/alarms/consumer/consumer.go b/alarms/consumer/consumer.go new file mode 100644 index 000000000..8b51979f7 --- /dev/null +++ b/alarms/consumer/consumer.go @@ -0,0 +1,57 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "bytes" + "context" + "encoding/gob" + "log/slog" + "time" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/messaging" +) + +type handler struct { + svc alarms.Service + logger *slog.Logger +} + +func Newhandler(svc alarms.Service, logger *slog.Logger) messaging.MessageHandler { + return &handler{svc: svc, logger: logger} +} + +func (h handler) Handle(msg *messaging.Message) (err error) { + if msg == nil { + return errors.New("message is empty") + } + if msg.GetPayload() == nil { + return errors.New("message payload is empty") + } + + var alarm alarms.Alarm + if err := gob.NewDecoder(bytes.NewReader(msg.GetPayload())).Decode(&alarm); err != nil { + return err + } + alarm.DomainID = msg.GetDomain() + alarm.ChannelID = msg.GetChannel() + alarm.ClientID = msg.GetPublisher() + alarm.Subtopic = msg.GetSubtopic() + + if alarm.CreatedAt.IsZero() { + alarm.CreatedAt = time.Unix(0, int64(msg.GetCreated())) + } + + if err := alarm.Validate(); err != nil { + return err + } + + return h.svc.CreateAlarm(context.Background(), alarm) +} + +func (h handler) Cancel() error { + return nil +} diff --git a/alarms/doc.go b/alarms/doc.go new file mode 100644 index 000000000..9f7866f33 --- /dev/null +++ b/alarms/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package alarms contains domain concept definitions needed to support +// Alarms service feature, i.e. create, read, update, and delete alarms. +package alarms diff --git a/alarms/middleware/authorization.go b/alarms/middleware/authorization.go new file mode 100644 index 000000000..bad879396 --- /dev/null +++ b/alarms/middleware/authorization.go @@ -0,0 +1,124 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/authn" + smqauthz "github.com/absmach/supermq/pkg/authz" + "github.com/absmach/supermq/pkg/policies" +) + +type authorizationMiddleware struct { + svc alarms.Service + authz smqauthz.Authorization +} + +var _ alarms.Service = (*authorizationMiddleware)(nil) + +func NewAuthorizationMiddleware(svc alarms.Service, authz smqauthz.Authorization) alarms.Service { + return &authorizationMiddleware{ + svc: svc, + authz: authz, + } +} + +func (am *authorizationMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (err error) { + return am.svc.CreateAlarm(ctx, alarm) +} + +func (am *authorizationMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (dba alarms.Alarm, err error) { + // if assignee is present check if assignee is member of domain + + req := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: policies.AdminPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + + if err := am.authz.Authorize(ctx, req); err != nil { + return alarms.Alarm{}, err + } + + if alarm.AssigneeID != "" { + domainUserId := auth.EncodeDomainUserID(session.DomainID, alarm.AssigneeID) + if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: domainUserId, + Permission: policies.MembershipPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + }); err != nil { + return alarms.Alarm{}, err + } + } + + return am.svc.UpdateAlarm(ctx, session, alarm) +} + +func (am *authorizationMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + req := smqauthz.PolicyReq{ + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: policies.AdminPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + + if err := am.authz.Authorize(ctx, req); err != nil { + return err + } + + return am.svc.DeleteAlarm(ctx, session, id) +} + +func (am *authorizationMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + if pm.DomainID == "" { + pm.DomainID = session.DomainID + } + + req := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: policies.MembershipPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + + if err := am.authz.Authorize(ctx, req); err != nil { + return alarms.AlarmsPage{}, err + } + + return am.svc.ListAlarms(ctx, session, pm) +} + +func (am *authorizationMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + req := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: policies.MembershipPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + + if err := am.authz.Authorize(ctx, req); err != nil { + return alarms.Alarm{}, err + } + + return am.svc.ViewAlarm(ctx, session, id) +} diff --git a/alarms/middleware/doc.go b/alarms/middleware/doc.go new file mode 100644 index 000000000..ce4a296d2 --- /dev/null +++ b/alarms/middleware/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package middleware provides middleware for the alarms service. +// This is logging, metrics, and tracing middleware. +package middleware diff --git a/alarms/middleware/logging.go b/alarms/middleware/logging.go new file mode 100644 index 000000000..c8e265358 --- /dev/null +++ b/alarms/middleware/logging.go @@ -0,0 +1,151 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/authn" + "github.com/go-chi/chi/v5/middleware" +) + +type loggingMiddleware struct { + logger *slog.Logger + service alarms.Service +} + +var _ alarms.Service = (*loggingMiddleware)(nil) + +func NewLoggingMiddleware(logger *slog.Logger, service alarms.Service) alarms.Service { + return &loggingMiddleware{ + logger: logger, + service: service, + } +} + +func (lm *loggingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.Group("alarm", + slog.String("rule_id", alarm.RuleID), + slog.String("domain_id", alarm.DomainID), + slog.String("channel_id", alarm.ChannelID), + slog.String("client_id", alarm.ClientID), + slog.String("subtopic", alarm.Subtopic), + slog.String("measurement", alarm.Measurement), + slog.String("value", alarm.Value), + slog.String("unit", alarm.Unit), + slog.String("threshold", alarm.Threshold), + slog.String("cause", alarm.Cause), + slog.Uint64("severity", uint64(alarm.Severity)), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Create alarm failed", args...) + return + } + lm.logger.Info("Create alarm completed successfully", args...) + }(time.Now()) + + return lm.service.CreateAlarm(ctx, alarm) +} + +func (lm *loggingMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (dba alarms.Alarm, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.Group("alarm", + slog.String("id", dba.ID), + slog.String("rule_id", dba.RuleID), + slog.String("domain_id", dba.DomainID), + slog.String("channel_id", dba.ChannelID), + slog.String("client_id", dba.ClientID), + slog.String("subtopic", dba.Subtopic), + slog.String("measurement", dba.Measurement), + slog.String("value", dba.Value), + slog.String("unit", dba.Unit), + slog.String("threshold", dba.Threshold), + slog.String("cause", dba.Cause), + slog.Uint64("severity", uint64(dba.Severity)), + slog.String("status", dba.Status.String()), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update alarm failed", args...) + return + } + lm.logger.Info("Update alarm completed successfully", args...) + }(time.Now()) + + return lm.service.UpdateAlarm(ctx, session, alarm) +} + +func (lm *loggingMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (dba alarms.Alarm, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.String("id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("View alarm failed", args...) + return + } + lm.logger.Info("View alarm completed successfully", args...) + }(time.Now()) + + return lm.service.ViewAlarm(ctx, session, id) +} + +func (lm *loggingMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (dbp alarms.AlarmsPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.Int("offset", int(pm.Offset)), + slog.Int("limit", int(pm.Limit)), + slog.String("rule_id", pm.RuleID), + slog.String("domain_id", pm.DomainID), + slog.String("channel_id", pm.ChannelID), + slog.String("client_id", pm.ClientID), + slog.String("subtopic", pm.Subtopic), + slog.Uint64("severity", uint64(pm.Severity)), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("List alarms failed", args...) + return + } + lm.logger.Info("List alarms completed successfully", args...) + }(time.Now()) + + return lm.service.ListAlarms(ctx, session, pm) +} + +func (lm *loggingMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.String("id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Delete alarm failed", args...) + return + } + lm.logger.Info("Delete alarm completed successfully", args...) + }(time.Now()) + + return lm.service.DeleteAlarm(ctx, session, id) +} diff --git a/alarms/middleware/metrics.go b/alarms/middleware/metrics.go new file mode 100644 index 000000000..3d99baafc --- /dev/null +++ b/alarms/middleware/metrics.go @@ -0,0 +1,74 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "time" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/authn" + "github.com/go-kit/kit/metrics" +) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + service alarms.Service +} + +var _ alarms.Service = (*metricsMiddleware)(nil) + +func NewMetricsMiddleware(counter metrics.Counter, latency metrics.Histogram, service alarms.Service) alarms.Service { + return &metricsMiddleware{ + counter: counter, + latency: latency, + service: service, + } +} + +func (mm *metricsMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + defer func(begin time.Time) { + mm.counter.With("method", "create_alarm").Add(1) + mm.latency.With("method", "create_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.CreateAlarm(ctx, alarm) +} + +func (mm *metricsMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_alarm").Add(1) + mm.latency.With("method", "update_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.UpdateAlarm(ctx, session, alarm) +} + +func (mm *metricsMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + defer func(begin time.Time) { + mm.counter.With("method", "get_alarm").Add(1) + mm.latency.With("method", "get_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ViewAlarm(ctx, session, id) +} + +func (mm *metricsMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_alarms").Add(1) + mm.latency.With("method", "list_alarms").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ListAlarms(ctx, session, pm) +} + +func (mm *metricsMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + defer func(begin time.Time) { + mm.counter.With("method", "delete_alarm").Add(1) + mm.latency.With("method", "delete_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.DeleteAlarm(ctx, session, id) +} diff --git a/alarms/middleware/tracing.go b/alarms/middleware/tracing.go new file mode 100644 index 000000000..a969af409 --- /dev/null +++ b/alarms/middleware/tracing.go @@ -0,0 +1,84 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/authn" + smqTracing "github.com/absmach/supermq/pkg/tracing" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +type tracingMiddleware struct { + tracer trace.Tracer + svc alarms.Service +} + +var _ alarms.Service = (*tracingMiddleware)(nil) + +func NewTracingMiddleware(tracer trace.Tracer, svc alarms.Service) alarms.Service { + return &tracingMiddleware{ + tracer: tracer, + svc: svc, + } +} + +func (tm *tracingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "create_alarm", trace.WithAttributes( + attribute.String("rule_id", alarm.RuleID), + attribute.String("measurement", alarm.Measurement), + attribute.String("value", alarm.Value), + attribute.String("unit", alarm.Unit), + attribute.String("cause", alarm.Cause), + attribute.String("status", alarm.Status.String()), + )) + defer span.End() + + return tm.svc.CreateAlarm(ctx, alarm) +} + +func (tm *tracingMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_alarm", trace.WithAttributes( + attribute.String("rule_id", alarm.RuleID), + attribute.String("measurement", alarm.Measurement), + attribute.String("value", alarm.Value), + attribute.String("unit", alarm.Unit), + attribute.String("cause", alarm.Cause), + attribute.String("status", alarm.Status.String()), + )) + defer span.End() + + return tm.svc.UpdateAlarm(ctx, session, alarm) +} + +func (tm *tracingMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "get_alarm", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.ViewAlarm(ctx, session, id) +} + +func (tm *tracingMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "list_alarms", trace.WithAttributes( + attribute.Int("offset", int(pm.Offset)), + attribute.Int("limit", int(pm.Limit)), + )) + defer span.End() + + return tm.svc.ListAlarms(ctx, session, pm) +} + +func (tm *tracingMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "delete_alarm", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.DeleteAlarm(ctx, session, id) +} diff --git a/alarms/mocks/repository.go b/alarms/mocks/repository.go new file mode 100644 index 000000000..12017d769 --- /dev/null +++ b/alarms/mocks/repository.go @@ -0,0 +1,309 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + mock "github.com/stretchr/testify/mock" +) + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// CreateAlarm provides a mock function for the type Repository +func (_mock *Repository) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + ret := _mock.Called(ctx, alarm) + + if len(ret) == 0 { + panic("no return value specified for CreateAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) (alarms.Alarm, error)); ok { + return returnFunc(ctx, alarm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) alarms.Alarm); ok { + r0 = returnFunc(ctx, alarm) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.Alarm) error); ok { + r1 = returnFunc(ctx, alarm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_CreateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlarm' +type Repository_CreateAlarm_Call struct { + *mock.Call +} + +// CreateAlarm is a helper method to define mock.On call +// - ctx +// - alarm +func (_e *Repository_Expecter) CreateAlarm(ctx interface{}, alarm interface{}) *Repository_CreateAlarm_Call { + return &Repository_CreateAlarm_Call{Call: _e.mock.On("CreateAlarm", ctx, alarm)} +} + +func (_c *Repository_CreateAlarm_Call) Run(run func(ctx context.Context, alarm alarms.Alarm)) *Repository_CreateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(alarms.Alarm)) + }) + return _c +} + +func (_c *Repository_CreateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Repository_CreateAlarm_Call { + _c.Call.Return(alarm1, err) + return _c +} + +func (_c *Repository_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error)) *Repository_CreateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// DeleteAlarm provides a mock function for the type Repository +func (_mock *Repository) DeleteAlarm(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteAlarm") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_DeleteAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteAlarm' +type Repository_DeleteAlarm_Call struct { + *mock.Call +} + +// DeleteAlarm is a helper method to define mock.On call +// - ctx +// - id +func (_e *Repository_Expecter) DeleteAlarm(ctx interface{}, id interface{}) *Repository_DeleteAlarm_Call { + return &Repository_DeleteAlarm_Call{Call: _e.mock.On("DeleteAlarm", ctx, id)} +} + +func (_c *Repository_DeleteAlarm_Call) Run(run func(ctx context.Context, id string)) *Repository_DeleteAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Repository_DeleteAlarm_Call) Return(err error) *Repository_DeleteAlarm_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_DeleteAlarm_Call) RunAndReturn(run func(ctx context.Context, id string) error) *Repository_DeleteAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ListAlarms provides a mock function for the type Repository +func (_mock *Repository) ListAlarms(ctx context.Context, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ret := _mock.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAlarms") + } + + var r0 alarms.AlarmsPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.PageMetadata) (alarms.AlarmsPage, error)); ok { + return returnFunc(ctx, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.PageMetadata) alarms.AlarmsPage); ok { + r0 = returnFunc(ctx, pm) + } else { + r0 = ret.Get(0).(alarms.AlarmsPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.PageMetadata) error); ok { + r1 = returnFunc(ctx, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAlarms' +type Repository_ListAlarms_Call struct { + *mock.Call +} + +// ListAlarms is a helper method to define mock.On call +// - ctx +// - pm +func (_e *Repository_Expecter) ListAlarms(ctx interface{}, pm interface{}) *Repository_ListAlarms_Call { + return &Repository_ListAlarms_Call{Call: _e.mock.On("ListAlarms", ctx, pm)} +} + +func (_c *Repository_ListAlarms_Call) Run(run func(ctx context.Context, pm alarms.PageMetadata)) *Repository_ListAlarms_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(alarms.PageMetadata)) + }) + return _c +} + +func (_c *Repository_ListAlarms_Call) Return(alarmsPage alarms.AlarmsPage, err error) *Repository_ListAlarms_Call { + _c.Call.Return(alarmsPage, err) + return _c +} + +func (_c *Repository_ListAlarms_Call) RunAndReturn(run func(ctx context.Context, pm alarms.PageMetadata) (alarms.AlarmsPage, error)) *Repository_ListAlarms_Call { + _c.Call.Return(run) + return _c +} + +// UpdateAlarm provides a mock function for the type Repository +func (_mock *Repository) UpdateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + ret := _mock.Called(ctx, alarm) + + if len(ret) == 0 { + panic("no return value specified for UpdateAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) (alarms.Alarm, error)); ok { + return returnFunc(ctx, alarm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) alarms.Alarm); ok { + r0 = returnFunc(ctx, alarm) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.Alarm) error); ok { + r1 = returnFunc(ctx, alarm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAlarm' +type Repository_UpdateAlarm_Call struct { + *mock.Call +} + +// UpdateAlarm is a helper method to define mock.On call +// - ctx +// - alarm +func (_e *Repository_Expecter) UpdateAlarm(ctx interface{}, alarm interface{}) *Repository_UpdateAlarm_Call { + return &Repository_UpdateAlarm_Call{Call: _e.mock.On("UpdateAlarm", ctx, alarm)} +} + +func (_c *Repository_UpdateAlarm_Call) Run(run func(ctx context.Context, alarm alarms.Alarm)) *Repository_UpdateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(alarms.Alarm)) + }) + return _c +} + +func (_c *Repository_UpdateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Repository_UpdateAlarm_Call { + _c.Call.Return(alarm1, err) + return _c +} + +func (_c *Repository_UpdateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error)) *Repository_UpdateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ViewAlarm provides a mock function for the type Repository +func (_mock *Repository) ViewAlarm(ctx context.Context, alarmID string, domainID string) (alarms.Alarm, error) { + ret := _mock.Called(ctx, alarmID, domainID) + + if len(ret) == 0 { + panic("no return value specified for ViewAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (alarms.Alarm, error)); ok { + return returnFunc(ctx, alarmID, domainID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) alarms.Alarm); ok { + r0 = returnFunc(ctx, alarmID, domainID) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, alarmID, domainID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ViewAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewAlarm' +type Repository_ViewAlarm_Call struct { + *mock.Call +} + +// ViewAlarm is a helper method to define mock.On call +// - ctx +// - alarmID +// - domainID +func (_e *Repository_Expecter) ViewAlarm(ctx interface{}, alarmID interface{}, domainID interface{}) *Repository_ViewAlarm_Call { + return &Repository_ViewAlarm_Call{Call: _e.mock.On("ViewAlarm", ctx, alarmID, domainID)} +} + +func (_c *Repository_ViewAlarm_Call) Run(run func(ctx context.Context, alarmID string, domainID string)) *Repository_ViewAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *Repository_ViewAlarm_Call) Return(alarm alarms.Alarm, err error) *Repository_ViewAlarm_Call { + _c.Call.Return(alarm, err) + return _c +} + +func (_c *Repository_ViewAlarm_Call) RunAndReturn(run func(ctx context.Context, alarmID string, domainID string) (alarms.Alarm, error)) *Repository_ViewAlarm_Call { + _c.Call.Return(run) + return _c +} diff --git a/alarms/mocks/service.go b/alarms/mocks/service.go new file mode 100644 index 000000000..6e984cc47 --- /dev/null +++ b/alarms/mocks/service.go @@ -0,0 +1,304 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "context" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/authn" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// CreateAlarm provides a mock function for the type Service +func (_mock *Service) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + ret := _mock.Called(ctx, alarm) + + if len(ret) == 0 { + panic("no return value specified for CreateAlarm") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) error); ok { + r0 = returnFunc(ctx, alarm) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_CreateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlarm' +type Service_CreateAlarm_Call struct { + *mock.Call +} + +// CreateAlarm is a helper method to define mock.On call +// - ctx +// - alarm +func (_e *Service_Expecter) CreateAlarm(ctx interface{}, alarm interface{}) *Service_CreateAlarm_Call { + return &Service_CreateAlarm_Call{Call: _e.mock.On("CreateAlarm", ctx, alarm)} +} + +func (_c *Service_CreateAlarm_Call) Run(run func(ctx context.Context, alarm alarms.Alarm)) *Service_CreateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(alarms.Alarm)) + }) + return _c +} + +func (_c *Service_CreateAlarm_Call) Return(err error) *Service_CreateAlarm_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) error) *Service_CreateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// DeleteAlarm provides a mock function for the type Service +func (_mock *Service) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteAlarm") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_DeleteAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteAlarm' +type Service_DeleteAlarm_Call struct { + *mock.Call +} + +// DeleteAlarm is a helper method to define mock.On call +// - ctx +// - session +// - id +func (_e *Service_Expecter) DeleteAlarm(ctx interface{}, session interface{}, id interface{}) *Service_DeleteAlarm_Call { + return &Service_DeleteAlarm_Call{Call: _e.mock.On("DeleteAlarm", ctx, session, id)} +} + +func (_c *Service_DeleteAlarm_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_DeleteAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authn.Session), args[2].(string)) + }) + return _c +} + +func (_c *Service_DeleteAlarm_Call) Return(err error) *Service_DeleteAlarm_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_DeleteAlarm_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) error) *Service_DeleteAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ListAlarms provides a mock function for the type Service +func (_mock *Service) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ret := _mock.Called(ctx, session, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAlarms") + } + + var r0 alarms.AlarmsPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.PageMetadata) (alarms.AlarmsPage, error)); ok { + return returnFunc(ctx, session, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.PageMetadata) alarms.AlarmsPage); ok { + r0 = returnFunc(ctx, session, pm) + } else { + r0 = ret.Get(0).(alarms.AlarmsPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, alarms.PageMetadata) error); ok { + r1 = returnFunc(ctx, session, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAlarms' +type Service_ListAlarms_Call struct { + *mock.Call +} + +// ListAlarms is a helper method to define mock.On call +// - ctx +// - session +// - pm +func (_e *Service_Expecter) ListAlarms(ctx interface{}, session interface{}, pm interface{}) *Service_ListAlarms_Call { + return &Service_ListAlarms_Call{Call: _e.mock.On("ListAlarms", ctx, session, pm)} +} + +func (_c *Service_ListAlarms_Call) Run(run func(ctx context.Context, session authn.Session, pm alarms.PageMetadata)) *Service_ListAlarms_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authn.Session), args[2].(alarms.PageMetadata)) + }) + return _c +} + +func (_c *Service_ListAlarms_Call) Return(alarmsPage alarms.AlarmsPage, err error) *Service_ListAlarms_Call { + _c.Call.Return(alarmsPage, err) + return _c +} + +func (_c *Service_ListAlarms_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error)) *Service_ListAlarms_Call { + _c.Call.Return(run) + return _c +} + +// UpdateAlarm provides a mock function for the type Service +func (_mock *Service) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + ret := _mock.Called(ctx, session, alarm) + + if len(ret) == 0 { + panic("no return value specified for UpdateAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.Alarm) (alarms.Alarm, error)); ok { + return returnFunc(ctx, session, alarm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.Alarm) alarms.Alarm); ok { + r0 = returnFunc(ctx, session, alarm) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, alarms.Alarm) error); ok { + r1 = returnFunc(ctx, session, alarm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAlarm' +type Service_UpdateAlarm_Call struct { + *mock.Call +} + +// UpdateAlarm is a helper method to define mock.On call +// - ctx +// - session +// - alarm +func (_e *Service_Expecter) UpdateAlarm(ctx interface{}, session interface{}, alarm interface{}) *Service_UpdateAlarm_Call { + return &Service_UpdateAlarm_Call{Call: _e.mock.On("UpdateAlarm", ctx, session, alarm)} +} + +func (_c *Service_UpdateAlarm_Call) Run(run func(ctx context.Context, session authn.Session, alarm alarms.Alarm)) *Service_UpdateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authn.Session), args[2].(alarms.Alarm)) + }) + return _c +} + +func (_c *Service_UpdateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Service_UpdateAlarm_Call { + _c.Call.Return(alarm1, err) + return _c +} + +func (_c *Service_UpdateAlarm_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error)) *Service_UpdateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ViewAlarm provides a mock function for the type Service +func (_mock *Service) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for ViewAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (alarms.Alarm, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) alarms.Alarm); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ViewAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewAlarm' +type Service_ViewAlarm_Call struct { + *mock.Call +} + +// ViewAlarm is a helper method to define mock.On call +// - ctx +// - session +// - id +func (_e *Service_Expecter) ViewAlarm(ctx interface{}, session interface{}, id interface{}) *Service_ViewAlarm_Call { + return &Service_ViewAlarm_Call{Call: _e.mock.On("ViewAlarm", ctx, session, id)} +} + +func (_c *Service_ViewAlarm_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_ViewAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(authn.Session), args[2].(string)) + }) + return _c +} + +func (_c *Service_ViewAlarm_Call) Return(alarm alarms.Alarm, err error) *Service_ViewAlarm_Call { + _c.Call.Return(alarm, err) + return _c +} + +func (_c *Service_ViewAlarm_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error)) *Service_ViewAlarm_Call { + _c.Call.Return(run) + return _c +} diff --git a/alarms/postgres/alarms.go b/alarms/postgres/alarms.go new file mode 100644 index 000000000..ddd85e127 --- /dev/null +++ b/alarms/postgres/alarms.go @@ -0,0 +1,427 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "math" + "strings" + "time" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" +) + +type repository struct { + db *sqlx.DB +} + +var _ alarms.Repository = (*repository)(nil) + +func NewAlarmsRepo(db *sqlx.DB) alarms.Repository { + return &repository{db: db} +} + +func (r *repository) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + query := `INSERT INTO alarms (id, rule_id, domain_id, channel_id, client_id, subtopic, measurement, value, unit, threshold, cause, status, severity, assignee_id, metadata, created_at) + VALUES (:id, :rule_id, :domain_id, :channel_id, :client_id, :subtopic, :measurement, :value, :unit, :threshold, :cause, :status, :severity, :assignee_id, :metadata, :created_at) + RETURNING id, rule_id, domain_id, channel_id, client_id, subtopic, measurement, value, unit, threshold, cause, status, severity, assignee_id, metadata, created_at;` + dba, err := toDBAlarm(alarm) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrCreateEntity, err) + } + row, err := r.db.NamedQueryContext(ctx, query, dba) + if err != nil { + return alarms.Alarm{}, postgres.HandleError(repoerr.ErrCreateEntity, err) + } + defer row.Close() + + if !row.Next() { + return alarms.Alarm{}, repoerr.ErrNotFound + } + + dba = dbAlarm{} + if err := row.StructScan(&dba); err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return toAlarm(dba) +} + +func (r *repository) UpdateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + var query []string + var upq string + if alarm.Status != 0 { + query = append(query, "status = :status,") + } + if alarm.AssigneeID != "" { + query = append(query, "assignee_id = :assignee_id,") + } + if !alarm.AssignedAt.IsZero() { + query = append(query, "assigned_at = :assigned_at,") + } + if alarm.AssignedBy != "" { + query = append(query, "assigned_by = :assigned_by,") + } + if alarm.AcknowledgedBy != "" { + query = append(query, "acknowledged_by = :acknowledged_by,") + } + if !alarm.AcknowledgedAt.IsZero() { + query = append(query, "acknowledged_at = :acknowledged_at,") + } + if alarm.ResolvedBy != "" { + query = append(query, "resolved_by = :resolved_by,") + } + if !alarm.ResolvedAt.IsZero() { + query = append(query, "resolved_at = :resolved_at,") + } + if alarm.Metadata != nil { + query = append(query, "metadata = :metadata,") + } + if len(query) > 0 { + upq = strings.Join(query, " ") + } + + q := fmt.Sprintf(`UPDATE alarms SET %s updated_by = :updated_by, updated_at = :updated_at WHERE id = :id + RETURNING id, rule_id, measurement, value, unit, cause, status, domain_id, assignee_id, metadata, created_at, updated_by, updated_at, resolved_by, resolved_at;`, upq) + + dba, err := toDBAlarm(alarm) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + row, err := r.db.NamedQueryContext(ctx, q, dba) + if err != nil { + return alarms.Alarm{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + + if !row.Next() { + return alarms.Alarm{}, repoerr.ErrNotFound + } + + dba = dbAlarm{} + if err := row.StructScan(&dba); err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return toAlarm(dba) +} + +func (r *repository) ViewAlarm(ctx context.Context, alarmID, domainID string) (alarms.Alarm, error) { + query := `SELECT * FROM alarms WHERE id = :id AND domain_id = :domain_id;` + row, err := r.db.NamedQueryContext(ctx, query, map[string]interface{}{ + "id": alarmID, "domain_id": domainID, + }) + if err != nil { + return alarms.Alarm{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + defer row.Close() + + if !row.Next() { + return alarms.Alarm{}, repoerr.ErrNotFound + } + + dba := dbAlarm{} + if err := row.StructScan(&dba); err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + alarm, err := toAlarm(dba) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return alarm, nil +} + +func (r *repository) ListAlarms(ctx context.Context, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + query, err := pageQuery(pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + q := fmt.Sprintf(`SELECT * FROM alarms %s ORDER BY created_at DESC LIMIT :limit OFFSET :offset;`, query) + rows, err := r.db.NamedQueryContext(ctx, q, pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + var items []alarms.Alarm + for rows.Next() { + dba := dbAlarm{} + if err := rows.StructScan(&dba); err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + a, err := toAlarm(dba) + if err != nil { + return alarms.AlarmsPage{}, err + } + + items = append(items, a) + } + + q = fmt.Sprintf(`SELECT COUNT(*) FROM alarms %s;`, query) + total, err := postgres.Total(ctx, r.db, q, pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return alarms.AlarmsPage{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + Alarms: items, + }, nil +} + +func (r *repository) DeleteAlarm(ctx context.Context, id string) error { + query := `DELETE FROM alarms WHERE id = :id;` + result, err := r.db.NamedExecContext(ctx, query, map[string]interface{}{"id": id}) + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + if rowsAffected == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +type dbAlarm struct { + ID string `db:"id"` + RuleID string `db:"rule_id"` + DomainID string `db:"domain_id"` + ChannelID string `db:"channel_id"` + ClientID string `db:"client_id"` + Subtopic string `db:"subtopic"` + Measurement string `db:"measurement"` + Value string `db:"value"` + Unit string `db:"unit"` + Cause string `db:"cause"` + Threshold string `db:"threshold"` + Status alarms.Status `db:"status"` + Severity uint8 `db:"severity"` + AssigneeID string `db:"assignee_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt sql.NullTime `db:"updated_at,omitempty"` + UpdatedBy *string `db:"updated_by,omitempty"` + AssignedAt sql.NullTime `db:"assigned_at,omitempty"` + AssignedBy *string `db:"assigned_by,omitempty"` + AcknowledgedAt sql.NullTime `db:"acknowledged_at,omitempty"` + AcknowledgedBy *string `db:"acknowledged_by,omitempty"` + ResolvedAt sql.NullTime `db:"resolved_at,omitempty"` + ResolvedBy *string `db:"resolved_by,omitempty"` + Metadata []byte `db:"metadata,omitempty"` +} + +func toDBAlarm(a alarms.Alarm) (dbAlarm, error) { + if a.CreatedAt.IsZero() { + a.CreatedAt = time.Now() + } + var updatedBy *string + if a.UpdatedBy != "" { + updatedBy = &a.UpdatedBy + } + var updatedAt sql.NullTime + if a.UpdatedAt != (time.Time{}) { + updatedAt = sql.NullTime{Time: a.UpdatedAt, Valid: true} + } + + var acknowledgedBy *string + if a.AcknowledgedBy != "" { + acknowledgedBy = &a.AcknowledgedBy + } + var acknowledgedAt sql.NullTime + if a.AcknowledgedAt != (time.Time{}) { + acknowledgedAt = sql.NullTime{Time: a.AcknowledgedAt, Valid: true} + } + + var resolvedBy *string + if a.ResolvedBy != "" { + resolvedBy = &a.ResolvedBy + } + var resolvedAt sql.NullTime + if a.ResolvedAt != (time.Time{}) { + resolvedAt = sql.NullTime{Time: a.ResolvedAt, Valid: true} + } + + var assignedBy *string + if a.AssignedBy != "" { + assignedBy = &a.AssignedBy + } + var assignedAt sql.NullTime + if a.AssignedAt != (time.Time{}) { + assignedAt = sql.NullTime{Time: a.AssignedAt, Valid: true} + } + + metadata := []byte("{}") + if len(a.Metadata) > 0 { + b, err := json.Marshal(a.Metadata) + if err != nil { + return dbAlarm{}, errors.Wrap(repoerr.ErrMalformedEntity, err) + } + metadata = b + } + + return dbAlarm{ + ID: a.ID, + RuleID: a.RuleID, + DomainID: a.DomainID, + ChannelID: a.ChannelID, + ClientID: a.ClientID, + Subtopic: a.Subtopic, + Measurement: a.Measurement, + Value: a.Value, + Unit: a.Unit, + Cause: a.Cause, + Threshold: a.Threshold, + Status: a.Status, + Severity: a.Severity, + AssigneeID: a.AssigneeID, + CreatedAt: a.CreatedAt, + UpdatedAt: updatedAt, + UpdatedBy: updatedBy, + AssignedAt: assignedAt, + AssignedBy: assignedBy, + AcknowledgedAt: acknowledgedAt, + AcknowledgedBy: acknowledgedBy, + ResolvedAt: resolvedAt, + ResolvedBy: resolvedBy, + Metadata: metadata, + }, nil +} + +func toAlarm(dbr dbAlarm) (alarms.Alarm, error) { + var updatedBy string + if dbr.UpdatedBy != nil { + updatedBy = *dbr.UpdatedBy + } + var updatedAt time.Time + if dbr.UpdatedAt.Valid { + updatedAt = dbr.UpdatedAt.Time + } + + var assignedBy string + if dbr.AssignedBy != nil { + assignedBy = *dbr.AssignedBy + } + var assignedAt time.Time + if dbr.AssignedAt.Valid { + assignedAt = dbr.AssignedAt.Time + } + + var acknowledgedBy string + if dbr.AcknowledgedBy != nil { + acknowledgedBy = *dbr.AcknowledgedBy + } + var acknowledgedAt time.Time + if dbr.AcknowledgedAt.Valid { + acknowledgedAt = dbr.AcknowledgedAt.Time + } + + var resolvedBy string + if dbr.ResolvedBy != nil { + resolvedBy = *dbr.ResolvedBy + } + var resolvedAt time.Time + if dbr.ResolvedAt.Valid { + resolvedAt = dbr.ResolvedAt.Time + } + + var metadata map[string]interface{} + if len(dbr.Metadata) > 0 { + err := json.Unmarshal(dbr.Metadata, &metadata) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrMalformedEntity, err) + } + } + + return alarms.Alarm{ + ID: dbr.ID, + RuleID: dbr.RuleID, + DomainID: dbr.DomainID, + ChannelID: dbr.ChannelID, + ClientID: dbr.ClientID, + Subtopic: dbr.Subtopic, + Measurement: dbr.Measurement, + Value: dbr.Value, + Unit: dbr.Unit, + Threshold: dbr.Threshold, + Cause: dbr.Cause, + Status: dbr.Status, + Severity: dbr.Severity, + AssigneeID: dbr.AssigneeID, + CreatedAt: dbr.CreatedAt, + UpdatedAt: updatedAt, + UpdatedBy: updatedBy, + AssignedAt: assignedAt, + AssignedBy: assignedBy, + AcknowledgedAt: acknowledgedAt, + AcknowledgedBy: acknowledgedBy, + ResolvedAt: resolvedAt, + ResolvedBy: resolvedBy, + Metadata: metadata, + }, nil +} + +func pageQuery(pm alarms.PageMetadata) (string, error) { + var query []string + if pm.DomainID != "" { + query = append(query, "domain_id = :domain_id") + } + if pm.ChannelID != "" { + query = append(query, "channel_id = :channel_id") + } + if pm.ClientID != "" { + query = append(query, "client_id = :client_id") + } + if pm.Subtopic != "" { + query = append(query, "subtopic = :subtopic") + } + if pm.RuleID != "" { + query = append(query, "rule_id = :rule_id") + } + if pm.Status != alarms.AllStatus { + query = append(query, "status = :status") + } + if pm.AssigneeID != "" { + query = append(query, "assignee_id = :assignee_id") + } + if pm.Severity != math.MaxUint8 { + query = append(query, "severity = :severity") + } + if pm.UpdatedBy != "" { + query = append(query, "updated_by = :updated_by") + } + if pm.ResolvedBy != "" { + query = append(query, "resolved_by = :resolved_by") + } + if pm.AcknowledgedBy != "" { + query = append(query, "acknowledged_by = :acknowledged_by") + } + if pm.AssignedBy != "" { + query = append(query, "assigned_by = :assigned_by") + } + + var emq string + if len(query) > 0 { + emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) + } + + return emq, nil +} diff --git a/alarms/postgres/alarms_test.go b/alarms/postgres/alarms_test.go new file mode 100644 index 000000000..b435da524 --- /dev/null +++ b/alarms/postgres/alarms_test.go @@ -0,0 +1,474 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + "github.com/absmach/magistrala/alarms" + "github.com/absmach/magistrala/alarms/postgres" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + namegen = namegenerator.NewGenerator() + idProvider = uuid.New() +) + +func TestCreateAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Subtopic: namegen.Generate(), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarm, + err: nil, + }, + { + desc: "duplicate alarm", + alarm: alarm, + err: repoerr.ErrConflict, + }, + { + desc: "missing rule id", + alarm: alarms.Alarm{ + ID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Subtopic: namegen.Generate(), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + + Metadata: map[string]interface{}{ + "key": "value", + }, + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "invalid alarm", + alarm: alarms.Alarm{ + ID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Subtopic: namegen.Generate(), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + + Metadata: map[string]interface{}{ + "key": make(chan int), + }, + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "empty alarm", + alarm: alarms.Alarm{}, + err: repoerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarm, err := repo.CreateAlarm(context.Background(), tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.NotEmpty(t, alarm.ID) + require.Equal(t, tc.alarm.RuleID, alarm.RuleID) + require.Equal(t, tc.alarm.Measurement, alarm.Measurement) + require.Equal(t, tc.alarm.Value, alarm.Value) + require.Equal(t, tc.alarm.Unit, alarm.Unit) + require.Equal(t, tc.alarm.Cause, alarm.Cause) + require.Equal(t, tc.alarm.Status, alarm.Status) + require.Equal(t, tc.alarm.DomainID, alarm.DomainID) + require.Equal(t, tc.alarm.AssigneeID, alarm.AssigneeID) + require.Equal(t, tc.alarm.Metadata, alarm.Metadata) + }) + } +} + +func TestUpdateAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + ID: alarm.ID, + Status: alarms.ActiveStatus, + DomainID: alarm.DomainID, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: alarm.CreatedAt, + UpdatedAt: time.Now().Local(), + UpdatedBy: generateUUID(&testing.T{}), + ResolvedAt: time.Now().Local(), + ResolvedBy: generateUUID(&testing.T{}), + Metadata: map[string]interface{}{ + "key": "value", + }, + }, + err: nil, + }, + { + desc: "non existing alarm", + alarm: alarms.Alarm{ + ID: generateUUID(&testing.T{}), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "invalid alarm", + alarm: alarms.Alarm{ + ID: alarm.ID, + RuleID: generateUUID(&testing.T{}), + Status: 0, + DomainID: generateUUID(&testing.T{}), + AssigneeID: strings.Repeat("a", 40), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + }, + err: repoerr.ErrMalformedEntity, + }, + { + desc: "empty alarm", + alarm: alarms.Alarm{}, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarm, err := repo.UpdateAlarm(context.Background(), tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.NotEmpty(t, alarm.ID) + require.Equal(t, tc.alarm.Status, alarm.Status) + require.Equal(t, tc.alarm.DomainID, alarm.DomainID) + require.Equal(t, tc.alarm.AssigneeID, alarm.AssigneeID) + require.Equal(t, tc.alarm.Metadata, alarm.Metadata) + }) + } +} + +func TestViewAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + domainID string + err error + }{ + { + desc: "valid alarm", + id: alarm.ID, + domainID: alarm.DomainID, + err: nil, + }, + { + desc: "non existing alarm id", + id: generateUUID(&testing.T{}), + domainID: alarm.DomainID, + err: repoerr.ErrNotFound, + }, + { + desc: "non existing domain id", + id: alarm.ID, + domainID: generateUUID(&testing.T{}), + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarm, err := repo.ViewAlarm(context.Background(), tc.id, tc.domainID) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.NotEmpty(t, alarm.ID) + require.Equal(t, tc.id, alarm.ID) + }) + } +} + +func TestListAlarms(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + repo := postgres.NewAlarmsRepo(db) + items := make([]alarms.Alarm, 1000) + for i := range 1000 { + items[i] = alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), items[i]) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + items[i].ID = alarm.ID + } + + cases := []struct { + desc string + pm alarms.PageMetadata + response []alarms.Alarm + err error + }{ + { + desc: "valid page", + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 10, + }, + response: items[:10], + err: nil, + }, + { + desc: "offset and limit", + pm: alarms.PageMetadata{ + Offset: 10, + Limit: 50, + }, + response: items[10:60], + err: nil, + }, + { + desc: "empty page", + pm: alarms.PageMetadata{}, + response: []alarms.Alarm{}, + err: nil, + }, + { + desc: "invalid page", + pm: alarms.PageMetadata{ + Offset: 1000, + Limit: 10, + }, + response: []alarms.Alarm{}, + err: nil, + }, + { + desc: "invalid assignee id", + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 10, + AssigneeID: generateUUID(&testing.T{}), + }, + response: []alarms.Alarm{}, + err: nil, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarms, err := repo.ListAlarms(context.Background(), tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, len(tc.response), len(alarms.Alarms)) + }) + } +} + +func TestDeleteAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(&testing.T{}), + RuleID: generateUUID(&testing.T{}), + DomainID: generateUUID(&testing.T{}), + ChannelID: generateUUID(&testing.T{}), + ClientID: generateUUID(&testing.T{}), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(&testing.T{}), + CreatedAt: time.Now().Local(), + Metadata: map[string]interface{}{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "valid alarm", + id: alarm.ID, + err: nil, + }, + { + desc: "non existing alarm", + id: generateUUID(&testing.T{}), + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.DeleteAlarm(context.Background(), tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + }) + } +} + +func generateUUID(t *testing.T) string { + ulid, err := idProvider.ID() + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + return ulid +} diff --git a/alarms/postgres/init.go b/alarms/postgres/init.go new file mode 100644 index 000000000..6ccb3ab44 --- /dev/null +++ b/alarms/postgres/init.go @@ -0,0 +1,52 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access + migrate "github.com/rubenv/sql-migrate" +) + +// Migration of Users service. +func Migration() *migrate.MemoryMigrationSource { + return &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "alarms_01", + // VARCHAR(36) for columns with IDs as UUIDS have a maximum of 36 characters + Up: []string{ + `CREATE TABLE IF NOT EXISTS alarms ( + id VARCHAR(36) PRIMARY KEY, + rule_id VARCHAR(36) NOT NULL CHECK (length(rule_id) > 0), + domain_id VARCHAR(36) NOT NULL, + channel_id VARCHAR(36) NOT NULL, + client_id VARCHAR(36) NOT NULL, + subtopic TEXT NOT NULL, + measurement TEXT NOT NULL, + value TEXT NOT NULL, + unit TEXT NOT NULL, + threshold TEXT NOT NULL, + cause TEXT NOT NULL, + status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0), + severity SMALLINT NOT NULL DEFAULT 0 CHECK (severity >= 0), + assignee_id VARCHAR(36), + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NULL, + updated_by VARCHAR(36) NULL, + assigned_at TIMESTAMPTZ NULL, + assigned_by VARCHAR(36) NULL, + acknowledged_at TIMESTAMPTZ NULL, + acknowledged_by VARCHAR(36) NULL, + resolved_at TIMESTAMPTZ NULL, + resolved_by VARCHAR(36) NULL, + metadata JSONB + );`, + }, + Down: []string{ + `DROP TABLE IF EXISTS alarms`, + }, + }, + }, + } +} diff --git a/alarms/postgres/setup_test.go b/alarms/postgres/setup_test.go new file mode 100644 index 000000000..3452d75ef --- /dev/null +++ b/alarms/postgres/setup_test.go @@ -0,0 +1,93 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "database/sql" + "fmt" + "log" + "os" + "testing" + "time" + + apostgres "github.com/absmach/magistrala/alarms/postgres" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" + dockertest "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "go.opentelemetry.io/otel" +) + +var ( + db *sqlx.DB + database postgres.Database + tracer = otel.Tracer("repo_tests") +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + // exponential backoff-retry, because the application in the container might not be ready to accept connections yet + pool.MaxWait = 120 * time.Second + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err := sql.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := postgres.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + if db, err = postgres.Setup(dbConfig, *apostgres.Migration()); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + database = postgres.NewDatabase(db, dbConfig, tracer) + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/alarms/service.go b/alarms/service.go new file mode 100644 index 000000000..c4bf28304 --- /dev/null +++ b/alarms/service.go @@ -0,0 +1,84 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms + +import ( + "context" + "time" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/pkg/authn" +) + +type service struct { + idp supermq.IDProvider + repo Repository +} + +var _ Service = (*service)(nil) + +func NewService(idp supermq.IDProvider, repo Repository) Service { + return &service{ + idp: idp, + repo: repo, + } +} + +func (s *service) CreateAlarm(ctx context.Context, alarm Alarm) error { + id, err := s.idp.ID() + if err != nil { + return err + } + alarm.ID = id + if alarm.CreatedAt.IsZero() { + alarm.CreatedAt = time.Now() + } + + if err := alarm.Validate(); err != nil { + return err + } + + pm := PageMetadata{ + Limit: 1, + Offset: 0, + DomainID: alarm.DomainID, + ChannelID: alarm.ChannelID, + ClientID: alarm.ClientID, + Subtopic: alarm.Subtopic, + RuleID: alarm.RuleID, + Severity: alarm.Severity, + Status: alarm.Status, + } + lastAlarms, err := s.repo.ListAlarms(ctx, pm) + if err != nil { + return err + } + + if len(lastAlarms.Alarms) > 0 { + return nil + } + + _, err = s.repo.CreateAlarm(ctx, alarm) + + return err +} + +func (s *service) ViewAlarm(ctx context.Context, session authn.Session, alarmID string) (Alarm, error) { + return s.repo.ViewAlarm(ctx, alarmID, session.DomainID) +} + +func (s *service) ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error) { + return s.repo.ListAlarms(ctx, pm) +} + +func (s *service) DeleteAlarm(ctx context.Context, session authn.Session, alarmID string) error { + return s.repo.DeleteAlarm(ctx, alarmID) +} + +func (s *service) UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error) { + alarm.UpdatedAt = time.Now() + alarm.UpdatedBy = session.UserID + + return s.repo.UpdateAlarm(ctx, alarm) +} diff --git a/alarms/service_test.go b/alarms/service_test.go new file mode 100644 index 000000000..aba72d339 --- /dev/null +++ b/alarms/service_test.go @@ -0,0 +1,253 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms_test + +import ( + "context" + "fmt" + "testing" + + "github.com/absmach/magistrala/alarms" + "github.com/absmach/magistrala/alarms/mocks" + "github.com/absmach/magistrala/pkg/errors" + "github.com/absmach/supermq/pkg/authn" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var idp = uuid.New() + +func TestCreateAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + RuleID: "rule-id", + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: nil, + }, + { + desc: "missing rule_id", + alarm: alarms.Alarm{ + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("rule_id is required"), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("CreateAlarm", context.Background(), mock.Anything).Return(tc.alarm, tc.err) + repoCall1 := repo.On("ListAlarms", context.Background(), alarms.PageMetadata{Offset: 0, Limit: 1}).Return(alarms.AlarmsPage{}, tc.err) + err := svc.CreateAlarm(context.Background(), tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + repoCall1.Unset() + }) + } +} + +func TestViewAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + id string + domainID string + err error + }{ + { + desc: "valid alarm", + id: "alarm-id", + domainID: "domain-id", + err: nil, + }, + { + desc: "non existing alarm id", + id: "alarm-id", + domainID: "domain-id", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.domainID} + repoCall := repo.On("ViewAlarm", context.Background(), tc.id, tc.domainID).Return(alarms.Alarm{}, tc.err) + _, err := svc.ViewAlarm(context.Background(), s, tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} + +func TestUpdateAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + RuleID: "rule-id", + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: nil, + }, + { + desc: "non existing alarm", + alarm: alarms.Alarm{ + RuleID: "rule-id", + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.alarm.DomainID} + repoCall := repo.On("UpdateAlarm", context.Background(), mock.Anything).Return(tc.alarm, tc.err) + _, err := svc.UpdateAlarm(context.Background(), s, tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} + +func TestListAlarms(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + pm alarms.PageMetadata + page alarms.AlarmsPage + err error + }{ + { + desc: "valid page", + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 10, + }, + page: alarms.AlarmsPage{ + Offset: 0, + Limit: 10, + Total: 10, + Alarms: []alarms.Alarm{}, + }, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.pm.DomainID} + repoCall := repo.On("ListAlarms", context.Background(), tc.pm).Return(tc.page, tc.err) + _, err := svc.ListAlarms(context.Background(), s, tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} + +func TestDeleteAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := alarms.NewService(idp, repo) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "valid alarm", + id: "alarm-id", + err: nil, + }, + { + desc: "non existing alarm", + id: "alarm-id", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.id} + repoCall := repo.On("DeleteAlarm", context.Background(), tc.id).Return(tc.err) + err := svc.DeleteAlarm(context.Background(), s, tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} diff --git a/alarms/status.go b/alarms/status.go new file mode 100644 index 000000000..90a6f7169 --- /dev/null +++ b/alarms/status.go @@ -0,0 +1,70 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms + +import ( + "encoding/json" + "strings" + + svcerr "github.com/absmach/supermq/pkg/errors/service" +) + +type Status uint8 + +const ( + ActiveStatus Status = iota + ClearedStatus + + // AllStatus is used for querying purposes to list alarms irrespective + // of their status. It is never stored in the database as the actual + // Alarm status and should always be the largest value in this enumeration. + AllStatus +) + +const ( + Active = "active" + Cleared = "cleared" + Unknown = "unknown" + All = "all" +) + +// String converts alarm status to string literal. +func (s Status) String() string { + switch s { + case ActiveStatus: + return Active + case ClearedStatus: + return Cleared + default: + return Unknown + } +} + +// ToStatus converts string value to a valid Alarm status. +func ToStatus(status string) (Status, error) { + switch status { + case Active: + return ActiveStatus, nil + case Cleared: + return ClearedStatus, nil + case All: + return AllStatus, nil + default: + return Status(0), svcerr.ErrInvalidStatus + } +} + +// Custom Marshaller for Alarm. +func (s Status) MarshalJSON() ([]byte, error) { + return json.Marshal(s.String()) +} + +// Custom Unmarshaler for Alarm. +func (s *Status) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToStatus(str) + *s = val + + return err +} diff --git a/bootstrap/events/producer/streams.go b/bootstrap/events/producer/streams.go index 750550854..5ea708bcf 100644 --- a/bootstrap/events/producer/streams.go +++ b/bootstrap/events/producer/streams.go @@ -14,22 +14,21 @@ import ( var _ bootstrap.Service = (*eventStore)(nil) const ( - streamPrefix = ".bootstrap" - addStream = streamPrefix + "add" - viewStream = streamPrefix + "view" - updateStream = streamPrefix + "update" - listStream = streamPrefix + "list" - bootstrapStream = streamPrefix + "bootstrap" - removeStream = streamPrefix + "remove" - updateCertStream = streamPrefix + "update_cert" - updateConnectionsStream = streamPrefix + "update_connections" - changeStateStream = streamPrefix + "change_state" - - connectClientHandlerStream = streamPrefix + "connect_client_handler" - disconnectClientHandlerStream = streamPrefix + "disconnect_client_handler" - removeConfigHandlerStream = streamPrefix + "remove_config_handler" - removeChannelHandlerStream = streamPrefix + "remove_channel_handler" - updateChannelHandlerStream = streamPrefix + "update_channel_handler" + magistralaPrefix = "magistrala." + createStream = magistralaPrefix + configCreate + viewStream = magistralaPrefix + configView + listStream = magistralaPrefix + configList + updateStream = magistralaPrefix + configUpdate + removeStream = magistralaPrefix + configRemove + updateCertStream = magistralaPrefix + certUpdate + updateConnectionsStream = magistralaPrefix + clientUpdateConnections + removeHandlerStream = magistralaPrefix + configHandlerRemove + bootstrapStream = magistralaPrefix + clientBootstrap + stateChangeStream = magistralaPrefix + clientStateChange + connectStream = magistralaPrefix + clientConnect + disconnectStream = magistralaPrefix + clientDisconnect + updateHandlerStream = magistralaPrefix + channelUpdateHandler + removeChannelHandlerStream = magistralaPrefix + channelHandlerRemove ) type eventStore struct { @@ -56,7 +55,7 @@ func (es *eventStore) Add(ctx context.Context, session smqauthn.Session, token s saved, configCreate, } - if err := es.Publish(ctx, addStream, ev); err != nil { + if err := es.Publish(ctx, createStream, ev); err != nil { return saved, err } @@ -72,7 +71,7 @@ func (es *eventStore) View(ctx context.Context, session smqauthn.Session, id str cfg, configView, } - if err := es.Publish(ctx, viewStream, ev); err != nil { + if err := es.Publish(ctx, configView, ev); err != nil { return cfg, err } @@ -88,7 +87,7 @@ func (es *eventStore) Update(ctx context.Context, session smqauthn.Session, cfg cfg, configUpdate, } - return es.Publish(ctx, updateStream, ev) + return es.Publish(ctx, configUpdate, ev) } func (es eventStore) UpdateCert(ctx context.Context, session smqauthn.Session, clientID, clientCert, clientKey, caCert string) (bootstrap.Config, error) { @@ -186,7 +185,7 @@ func (es *eventStore) ChangeState(ctx context.Context, session smqauthn.Session, state: state, } - return es.Publish(ctx, changeStateStream, ev) + return es.Publish(ctx, stateChangeStream, ev) } func (es *eventStore) RemoveConfigHandler(ctx context.Context, id string) error { @@ -199,7 +198,7 @@ func (es *eventStore) RemoveConfigHandler(ctx context.Context, id string) error operation: configHandlerRemove, } - return es.Publish(ctx, removeConfigHandlerStream, ev) + return es.Publish(ctx, removeHandlerStream, ev) } func (es *eventStore) RemoveChannelHandler(ctx context.Context, id string) error { @@ -224,7 +223,7 @@ func (es *eventStore) UpdateChannelHandler(ctx context.Context, channel bootstra channel, } - return es.Publish(ctx, updateChannelHandlerStream, ev) + return es.Publish(ctx, updateStream, ev) } func (es *eventStore) ConnectClientHandler(ctx context.Context, channelID, clientID string) error { @@ -237,7 +236,7 @@ func (es *eventStore) ConnectClientHandler(ctx context.Context, channelID, clien channelID: channelID, } - return es.Publish(ctx, connectClientHandlerStream, ev) + return es.Publish(ctx, connectStream, ev) } func (es *eventStore) DisconnectClientHandler(ctx context.Context, channelID, clientID string) error { @@ -250,5 +249,5 @@ func (es *eventStore) DisconnectClientHandler(ctx context.Context, channelID, cl channelID: channelID, } - return es.Publish(ctx, disconnectClientHandlerStream, ev) + return es.Publish(ctx, disconnectStream, ev) } diff --git a/bootstrap/events/producer/streams_test.go b/bootstrap/events/producer/streams_test.go index cf9d43d0d..fe83cea10 100644 --- a/bootstrap/events/producer/streams_test.go +++ b/bootstrap/events/producer/streams_test.go @@ -192,7 +192,7 @@ func TestAdd(t *testing.T) { lastID := "0" for _, tc := range cases { tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} - sdkCall := tv.sdk.On("Client", tc.config.ClientID, tc.domainID, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, errors.NewSDKError(tc.clientErr)) + sdkCall := tv.sdk.On("Client", mock.Anything, tc.config.ClientID, tc.domainID, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, errors.NewSDKError(tc.clientErr)) repoCall := tv.boot.On("ListExisting", context.Background(), domainID, mock.Anything).Return(tc.config.Channels, tc.listErr) repoCall1 := tv.boot.On("Save", context.Background(), mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) @@ -475,7 +475,7 @@ func TestUpdateConnections(t *testing.T) { lastID := "0" for _, tc := range cases { tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} - sdkCall := tv.sdk.On("Channel", mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) + sdkCall := tv.sdk.On("Channel", mock.Anything, mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.configID).Return(config, tc.retrieveErr) repoCall1 := tv.boot.On("ListExisting", context.Background(), domainID, mock.Anything, mock.Anything).Return(config.Channels, tc.listErr) repoCall2 := tv.boot.On("UpdateConnections", context.Background(), tc.domainID, tc.configID, mock.Anything, tc.connections).Return(tc.updateErr) @@ -1054,7 +1054,7 @@ func TestChangeState(t *testing.T) { for _, tc := range cases { tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(config, tc.retrieveErr) - sdkCall1 := tv.sdk.On("ConnectClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.NewSDKError(tc.connectErr)) + sdkCall1 := tv.sdk.On("ConnectClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.NewSDKError(tc.connectErr)) repoCall1 := tv.boot.On("ChangeState", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.stateErr) err := tv.svc.ChangeState(context.Background(), tc.session, tc.token, tc.id, tc.state) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) diff --git a/bootstrap/service.go b/bootstrap/service.go index 9e27d70f2..ab3709db9 100644 --- a/bootstrap/service.go +++ b/bootstrap/service.go @@ -144,6 +144,7 @@ func (bs bootstrapService) Add(ctx context.Context, session smqauthn.Session, to if err != nil { return Config{}, errors.Wrap(errCheckChannels, err) } + cfg.Channels, err = bs.connectionChannels(ctx, toConnect, bs.toIDList(existing), session.DomainID, token) if err != nil { return Config{}, errors.Wrap(errConnectionChannels, err) diff --git a/bootstrap/service_test.go b/bootstrap/service_test.go index 9ba5b5984..48b78ed4b 100644 --- a/bootstrap/service_test.go +++ b/bootstrap/service_test.go @@ -152,9 +152,9 @@ func TestAdd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} - repoCall := sdk.On("Client", tc.config.ClientID, mock.Anything, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, tc.clientErr) - repoCall1 := sdk.On("CreateClient", mock.Anything, tc.domainID, tc.token).Return(mgsdk.Client{}, tc.createClientErr) - repoCall2 := sdk.On("DeleteClient", tc.config.ClientID, tc.domainID, tc.token).Return(tc.deleteClientErr) + repoCall := sdk.On("Client", mock.Anything, tc.config.ClientID, mock.Anything, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, tc.clientErr) + repoCall1 := sdk.On("CreateClient", mock.Anything, mock.Anything, tc.domainID, tc.token).Return(mgsdk.Client{}, tc.createClientErr) + repoCall2 := sdk.On("DeleteClient", mock.Anything, tc.config.ClientID, tc.domainID, tc.token).Return(tc.deleteClientErr) repoCall3 := boot.On("ListExisting", context.Background(), tc.domainID, mock.Anything).Return(tc.config.Channels, tc.listExistingErr) repoCall4 := boot.On("Save", context.Background(), mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) _, err := svc.Add(context.Background(), tc.session, tc.token, tc.config) @@ -434,7 +434,7 @@ func TestUpdateConnections(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} - sdkCall := sdk.On("Channel", mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) + sdkCall := sdk.On("Channel", mock.Anything, mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) repoCall := boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(c, tc.retrieveErr) repoCall1 := boot.On("ListExisting", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(c.Channels, tc.listErr) repoCall2 := boot.On("UpdateConnections", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.updateErr) @@ -939,7 +939,7 @@ func TestChangeState(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} repoCall := boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(c, tc.retrieveErr) - sdkCall := sdk.On("ConnectClients", mock.Anything, mock.Anything, []string{"Publish", "Subscribe"}, mock.Anything, tc.token).Return(tc.connectErr) + sdkCall := sdk.On("ConnectClients", mock.Anything, mock.Anything, mock.Anything, []string{"Publish", "Subscribe"}, mock.Anything, tc.token).Return(tc.connectErr) repoCall1 := boot.On("ChangeState", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.stateErr) err := svc.ChangeState(context.Background(), tc.session, tc.token, tc.id, tc.state) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) diff --git a/cli/bootstrap.go b/cli/bootstrap.go index 69daac7a6..1acd80cd8 100644 --- a/cli/bootstrap.go +++ b/cli/bootstrap.go @@ -27,7 +27,7 @@ var cmdBootstrap = []cobra.Command{ return } - id, err := sdk.AddBootstrap(cfg, args[1], args[2]) + id, err := sdk.AddBootstrap(cmd.Context(), cfg, args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -54,7 +54,7 @@ var cmdBootstrap = []cobra.Command{ Name: Name, } if args[0] == "all" { - l, err := sdk.Bootstraps(pageMetadata, args[1], args[2]) + l, err := sdk.Bootstraps(cmd.Context(), pageMetadata, args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -63,7 +63,7 @@ var cmdBootstrap = []cobra.Command{ return } - c, err := sdk.ViewBootstrap(args[0], args[1], args[2]) + c, err := sdk.ViewBootstrap(cmd.Context(), args[0], args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -92,7 +92,7 @@ var cmdBootstrap = []cobra.Command{ return } - if err := sdk.UpdateBootstrap(cfg, args[1], args[2]); err != nil { + if err := sdk.UpdateBootstrap(cmd.Context(), cfg, args[1], args[2]); err != nil { logErrorCmd(*cmd, err) return } @@ -106,7 +106,7 @@ var cmdBootstrap = []cobra.Command{ logErrorCmd(*cmd, err) return } - if err := sdk.UpdateBootstrapConnection(args[1], ids, args[3], args[4]); err != nil { + if err := sdk.UpdateBootstrapConnection(cmd.Context(), args[1], ids, args[3], args[4]); err != nil { logErrorCmd(*cmd, err) return } @@ -115,7 +115,7 @@ var cmdBootstrap = []cobra.Command{ return } if args[0] == "certs" { - cfg, err := sdk.UpdateBootstrapCerts(args[0], args[1], args[2], args[3], args[4], args[5]) + cfg, err := sdk.UpdateBootstrapCerts(cmd.Context(), args[0], args[1], args[2], args[3], args[4], args[5]) if err != nil { logErrorCmd(*cmd, err) return @@ -137,7 +137,7 @@ var cmdBootstrap = []cobra.Command{ return } - if err := sdk.RemoveBootstrap(args[0], args[1], args[2]); err != nil { + if err := sdk.RemoveBootstrap(cmd.Context(), args[0], args[1], args[2]); err != nil { logErrorCmd(*cmd, err) return } @@ -156,7 +156,7 @@ var cmdBootstrap = []cobra.Command{ return } if args[0] == "secure" { - c, err := sdk.BootstrapSecure(args[1], args[2], args[3]) + c, err := sdk.BootstrapSecure(cmd.Context(), args[1], args[2], args[3]) if err != nil { logErrorCmd(*cmd, err) return @@ -165,7 +165,7 @@ var cmdBootstrap = []cobra.Command{ logJSONCmd(*cmd, c) return } - c, err := sdk.Bootstrap(args[0], args[1]) + c, err := sdk.Bootstrap(cmd.Context(), args[0], args[1]) if err != nil { logErrorCmd(*cmd, err) return @@ -190,7 +190,7 @@ var cmdBootstrap = []cobra.Command{ return } - if err := sdk.Whitelist(cfg.ClientID, cfg.State, args[1], args[2]); err != nil { + if err := sdk.Whitelist(cmd.Context(), cfg.ClientID, cfg.State, args[1], args[2]); err != nil { logErrorCmd(*cmd, err) return } diff --git a/cli/bootstrap_test.go b/cli/bootstrap_test.go index f549c9b01..3b28677cd 100644 --- a/cli/bootstrap_test.go +++ b/cli/bootstrap_test.go @@ -102,7 +102,7 @@ func TestCreateBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("AddBootstrap", mock.Anything, mock.Anything, mock.Anything).Return(tc.id, tc.sdkErr) + sdkCall := sdkMock.On("AddBootstrap", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.id, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{createCmd}, tc.args...)...) switch tc.logType { @@ -199,8 +199,8 @@ func TestGetBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("ViewBootstrap", tc.args[0], tc.args[1], tc.args[2]).Return(tc.boot, tc.sdkErr) - sdkCall1 := sdkMock.On("Bootstraps", mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr) + sdkCall := sdkMock.On("ViewBootstrap", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.boot, tc.sdkErr) + sdkCall1 := sdkMock.On("Bootstraps", mock.Anything, mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...) @@ -284,7 +284,7 @@ func TestRemoveBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("RemoveBootstrap", tc.args[0], tc.args[1], tc.args[2]).Return(tc.sdkErr) + sdkCall := sdkMock.On("RemoveBootstrap", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{rmCmd}, tc.args...)...) switch tc.logType { @@ -443,9 +443,9 @@ func TestUpdateBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { var boot mgsdk.BootstrapConfig - sdkCall := sdkMock.On("UpdateBootstrap", mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) - sdkCall1 := sdkMock.On("UpdateBootstrapConnection", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) - sdkCall2 := sdkMock.On("UpdateBootstrapCerts", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) + sdkCall := sdkMock.On("UpdateBootstrap", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) + sdkCall1 := sdkMock.On("UpdateBootstrapConnection", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) + sdkCall2 := sdkMock.On("UpdateBootstrapCerts", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{updCmd}, tc.args...)...) switch tc.logType { @@ -527,7 +527,7 @@ func TestWhitelistConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("Whitelist", mock.Anything, mock.Anything, tc.args[1], tc.args[2]).Return(tc.sdkErr) + sdkCall := sdkMock.On("Whitelist", mock.Anything, mock.Anything, mock.Anything, tc.args[1], tc.args[2]).Return(tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{whitelistCmd}, tc.args...)...) switch tc.logType { case okLog: @@ -613,8 +613,8 @@ func TestBootstrapConfigCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("BootstrapSecure", mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) - sdkCall1 := sdkMock.On("Bootstrap", mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) + sdkCall := sdkMock.On("BootstrapSecure", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) + sdkCall1 := sdkMock.On("Bootstrap", mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{bootStrapCmd}, tc.args...)...) switch tc.logType { case entityLog: diff --git a/cli/consumers.go b/cli/consumers.go index 64bcd6009..99ef471df 100644 --- a/cli/consumers.go +++ b/cli/consumers.go @@ -19,7 +19,7 @@ var cmdSubscription = []cobra.Command{ return } - id, err := sdk.CreateSubscription(args[0], args[1], args[2]) + id, err := sdk.CreateSubscription(cmd.Context(), args[0], args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -46,7 +46,7 @@ var cmdSubscription = []cobra.Command{ Contact: Contact, } if args[0] == "all" { - sub, err := sdk.ListSubscriptions(pageMetadata, args[1]) + sub, err := sdk.ListSubscriptions(cmd.Context(), pageMetadata, args[1]) if err != nil { logErrorCmd(*cmd, err) return @@ -55,7 +55,7 @@ var cmdSubscription = []cobra.Command{ return } - c, err := sdk.ViewSubscription(args[0], args[1]) + c, err := sdk.ViewSubscription(cmd.Context(), args[0], args[1]) if err != nil { logErrorCmd(*cmd, err) return @@ -74,7 +74,7 @@ var cmdSubscription = []cobra.Command{ return } - if err := sdk.DeleteSubscription(args[0], args[1]); err != nil { + if err := sdk.DeleteSubscription(cmd.Context(), args[0], args[1]); err != nil { logErrorCmd(*cmd, err) return } diff --git a/cli/consumers_test.go b/cli/consumers_test.go index 0cb27696c..09a04f485 100644 --- a/cli/consumers_test.go +++ b/cli/consumers_test.go @@ -81,7 +81,7 @@ func TestCreateSubscriptionCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("CreateSubscription", tc.args[0], tc.args[1], tc.args[2]).Return(tc.id, tc.sdkErr) + sdkCall := sdkMock.On("CreateSubscription", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.id, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{createCmd}, tc.args...)...) switch tc.logType { @@ -168,8 +168,8 @@ func TestGetSubscriptionsCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("ViewSubscription", tc.args[0], tc.args[1]).Return(tc.subscription, tc.sdkErr) - sdkCall1 := sdkMock.On("ListSubscriptions", mock.Anything, tc.args[1]).Return(tc.page, tc.sdkErr) + sdkCall := sdkMock.On("ViewSubscription", mock.Anything, tc.args[0], tc.args[1]).Return(tc.subscription, tc.sdkErr) + sdkCall1 := sdkMock.On("ListSubscriptions", mock.Anything, mock.Anything, tc.args[1]).Return(tc.page, tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...) @@ -249,7 +249,7 @@ func TestRemoveSubscriptionCmd(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - sdkCall := sdkMock.On("DeleteSubscription", tc.args[0], tc.args[1]).Return(tc.sdkErr) + sdkCall := sdkMock.On("DeleteSubscription", mock.Anything, tc.args[0], tc.args[1]).Return(tc.sdkErr) out := executeCommand(t, rootCmd, append([]string{rmCmd}, tc.args...)...) switch tc.logType { diff --git a/cli/provision.go b/cli/provision.go index 5a0e1b121..282474303 100644 --- a/cli/provision.go +++ b/cli/provision.go @@ -4,7 +4,6 @@ package cli import ( - "context" "encoding/csv" "encoding/json" "errors" @@ -54,7 +53,7 @@ var cmdProvision = []cobra.Command{ return } - clients, err = sdk.CreateClients(context.Background(), clients, args[1], args[2]) + clients, err = sdk.CreateClients(cmd.Context(), clients, args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -81,7 +80,7 @@ var cmdProvision = []cobra.Command{ var chs []smqsdk.Channel for _, c := range channels { - c, err = sdk.CreateChannel(context.Background(), c, args[1], args[2]) + c, err = sdk.CreateChannel(cmd.Context(), c, args[1], args[2]) if err != nil { logErrorCmd(*cmd, err) return @@ -109,7 +108,7 @@ var cmdProvision = []cobra.Command{ return } for _, conn := range connIDs { - if err := sdk.Connect(context.Background(), conn, args[1], args[2]); err != nil { + if err := sdk.Connect(cmd.Context(), conn, args[1], args[2]); err != nil { logErrorCmd(*cmd, err) return } @@ -146,13 +145,13 @@ var cmdProvision = []cobra.Command{ }, Status: smqsdk.EnabledStatus, } - user, err := sdk.CreateUser(context.Background(), user, "") + user, err := sdk.CreateUser(cmd.Context(), user, "") if err != nil { logErrorCmd(*cmd, err) return } - ut, err := sdk.CreateToken(context.Background(), smqsdk.Login{Username: user.Credentials.Username, Password: user.Credentials.Secret}) + ut, err := sdk.CreateToken(cmd.Context(), smqsdk.Login{Username: user.Credentials.Username, Password: user.Credentials.Secret}) if err != nil { logErrorCmd(*cmd, err) return @@ -163,13 +162,13 @@ var cmdProvision = []cobra.Command{ Name: fmt.Sprintf("%s-domain", name), Status: smqsdk.EnabledStatus, } - domain, err = sdk.CreateDomain(context.Background(), domain, ut.AccessToken) + domain, err = sdk.CreateDomain(cmd.Context(), domain, ut.AccessToken) if err != nil { logErrorCmd(*cmd, err) return } - ut, err = sdk.CreateToken(context.Background(), smqsdk.Login{Username: user.Email, Password: user.Credentials.Secret}) + ut, err = sdk.CreateToken(cmd.Context(), smqsdk.Login{Username: user.Email, Password: user.Credentials.Secret}) if err != nil { logErrorCmd(*cmd, err) return @@ -184,7 +183,7 @@ var cmdProvision = []cobra.Command{ clients = append(clients, t) } - clients, err = sdk.CreateClients(context.Background(), clients, domain.ID, ut.AccessToken) + clients, err = sdk.CreateClients(cmd.Context(), clients, domain.ID, ut.AccessToken) if err != nil { logErrorCmd(*cmd, err) return @@ -196,7 +195,7 @@ var cmdProvision = []cobra.Command{ Name: fmt.Sprintf("%s-channel-%d", name, i), Status: smqsdk.EnabledStatus, } - c, err = sdk.CreateChannel(context.Background(), c, domain.ID, ut.AccessToken) + c, err = sdk.CreateChannel(cmd.Context(), c, domain.ID, ut.AccessToken) if err != nil { logErrorCmd(*cmd, err) return @@ -211,7 +210,7 @@ var cmdProvision = []cobra.Command{ ClientIDs: []string{clients[0].ID}, Types: []string{PublishType, SubscribeType}, } - if err := sdk.Connect(context.Background(), conIDs, domain.ID, ut.AccessToken); err != nil { + if err := sdk.Connect(cmd.Context(), conIDs, domain.ID, ut.AccessToken); err != nil { logErrorCmd(*cmd, err) return } @@ -221,7 +220,7 @@ var cmdProvision = []cobra.Command{ ClientIDs: []string{clients[0].ID}, Types: []string{PublishType, SubscribeType}, } - if err := sdk.Connect(context.Background(), conIDs, domain.ID, ut.AccessToken); err != nil { + if err := sdk.Connect(cmd.Context(), conIDs, domain.ID, ut.AccessToken); err != nil { logErrorCmd(*cmd, err) return } @@ -231,21 +230,21 @@ var cmdProvision = []cobra.Command{ ClientIDs: []string{clients[1].ID}, Types: []string{PublishType, SubscribeType}, } - if err := sdk.Connect(context.Background(), conIDs, domain.ID, ut.AccessToken); err != nil { + if err := sdk.Connect(cmd.Context(), conIDs, domain.ID, ut.AccessToken); err != nil { logErrorCmd(*cmd, err) return } // send message to test connectivity - if err := sdk.SendMessage(context.Background(), domain.ID, channels[0].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { + if err := sdk.SendMessage(cmd.Context(), domain.ID, channels[0].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { logErrorCmd(*cmd, err) return } - if err := sdk.SendMessage(context.Background(), domain.ID, channels[0].ID, clients[1].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { + if err := sdk.SendMessage(cmd.Context(), domain.ID, channels[0].ID, clients[1].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { logErrorCmd(*cmd, err) return } - if err := sdk.SendMessage(context.Background(), domain.ID, channels[1].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { + if err := sdk.SendMessage(cmd.Context(), domain.ID, channels[1].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { logErrorCmd(*cmd, err) return } diff --git a/cmd/alarms/main.go b/cmd/alarms/main.go new file mode 100644 index 000000000..23b0d91b0 --- /dev/null +++ b/cmd/alarms/main.go @@ -0,0 +1,192 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "fmt" + "log" + "net/url" + "os" + + "github.com/absmach/magistrala/alarms" + httpAPI "github.com/absmach/magistrala/alarms/api" + "github.com/absmach/magistrala/alarms/consumer" + "github.com/absmach/magistrala/alarms/consumer/brokers" + "github.com/absmach/magistrala/alarms/middleware" + alarmsRepo "github.com/absmach/magistrala/alarms/postgres" + "github.com/absmach/magistrala/pkg/prometheus" + smqlog "github.com/absmach/supermq/logger" + "github.com/absmach/supermq/pkg/authn/authsvc" + authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" + domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + "github.com/absmach/supermq/pkg/grpcclient" + "github.com/absmach/supermq/pkg/jaeger" + "github.com/absmach/supermq/pkg/messaging" + brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/server" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + "github.com/caarlos0/env/v11" + "golang.org/x/sync/errgroup" +) + +const ( + svcName = "alarms" + envPrefixDB = "MG_ALARMS_DB_" + envPrefixHTTP = "MG_ALARMS_HTTP_" + envPrefixAuth = "SMQ_AUTH_GRPC_" + defDB = "alarms" + defSvcHTTPPort = "8050" + envPrefixDomains = "SMQ_DOMAINS_GRPC_" +) + +type config struct { + LogLevel string `env:"MG_ALARMS_LOG_LEVEL" envDefault:"info"` + BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` + InstanceID string `env:"MG_ALARMS_INSTANCE_ID" envDefault:""` + JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err.Error()) + } + + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + tp, err := jaeger.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err)) + exitCode = 1 + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("error shutting down tracer provider: %v", err)) + } + }() + tracer := tp.Tracer(svcName) + + dbConfig := postgres.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(err.Error()) + } + + db, err := postgres.Setup(dbConfig, *alarmsRepo.Migration()) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer db.Close() + + repo := alarmsRepo.NewAlarmsRepo(db) + + authConfig := grpcclient.Config{} + if err := env.ParseWithOptions(&authConfig, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err)) + exitCode = 1 + return + } + authn, authnClient, err := authsvc.NewAuthentication(ctx, authConfig) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer authnClient.Close() + logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + + domsGrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { + logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + domAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer domainsHandler.Close() + + authz, authzHandler, err := authsvcAuthz.NewAuthorization(ctx, authConfig, domAuthz) + if err != nil { + logger.Error("failed to create authz " + err.Error()) + exitCode = 1 + return + } + defer authzHandler.Close() + + logger.Info("AuthZ successfully connected to auth gRPC server " + authzHandler.Secure()) + + idp := uuid.New() + + svc := alarms.NewService(idp, repo) + + svc = middleware.NewAuthorizationMiddleware(svc, authz) + svc = middleware.NewLoggingMiddleware(logger, svc) + counter, latency := prometheus.MakeMetrics("alarms", "api") + svc = middleware.NewMetricsMiddleware(counter, latency, svc) + svc = middleware.NewTracingMiddleware(tracer, svc) + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + return + } + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpAPI.MakeHandler(svc, logger, idp, cfg.InstanceID, authn), logger) + + pubSub, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err)) + exitCode = 1 + return + } + defer pubSub.Close() + pubSub = brokerstracing.NewPubSub(httpServerConfig, tracer, pubSub) + + consumer := consumer.Newhandler(svc, logger) + + subCfg := messaging.SubscriberConfig{ + ID: svcName, + Topic: brokers.AllTopic, + DeliveryPolicy: messaging.DeliverAllPolicy, + Handler: consumer, + } + if err := pubSub.Subscribe(ctx, subCfg); err != nil { + logger.Error(fmt.Sprintf("failed to subscribe to message broker: %s", err)) + exitCode = 1 + + return + } + + g.Go(func() error { + return hs.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("billing service terminated: %s", err)) + } +} diff --git a/docker/.env b/docker/.env index 2d8e76b0f..adaf6664c 100644 --- a/docker/.env +++ b/docker/.env @@ -119,6 +119,23 @@ MG_EMAIL_FROM_ADDRESS=from@example.com MG_EMAIL_FROM_NAME=Example MG_EMAIL_TEMPLATE=email.tmpl +### Alarms +MG_ALARMS_LOG_LEVEL=debug +MG_ALARMS_HTTP_HOST=alarms +MG_ALARMS_HTTP_PORT=8050 +MG_ALARMS_HTTP_SERVER_CERT= +MG_ALARMS_HTTP_SERVER_KEY= +MG_ALARMS_DB_HOST=alarms-db +MG_ALARMS_DB_PORT=5432 +MG_ALARMS_DB_USER=magistrala +MG_ALARMS_DB_PASS=magistrala +MG_ALARMS_DB_NAME=alarms +MG_ALARMS_DB_SSL_MODE=disable +MG_ALARMS_DB_SSL_CERT= +MG_ALARMS_DB_SSL_KEY= +MG_ALARMS_DB_SSL_ROOT_CERT= +MG_ALARMS_INSTANCE_ID= + ### Certs SMQ_ADDONS_CERTS_PATH_PREFIX=./ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1301957e9..a6abe34e3 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -22,6 +22,7 @@ volumes: magistrala-ui-backend-db-volume: magistrala-re-db-volume: magistrala-auth-redis-volume: + magistrala-alarms-db-volume: services: ui: @@ -214,3 +215,91 @@ services: bind: create_host_path: true - ./re_config.toml:/config.toml + + alarms-db: + image: postgres:16.2-alpine + container_name: magistrala-alarms-db + restart: on-failure + command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" + environment: + POSTGRES_USER: ${MG_ALARMS_DB_USER} + POSTGRES_PASSWORD: ${MG_ALARMS_DB_PASS} + POSTGRES_DB: ${MG_ALARMS_DB_NAME} + ports: + - 6019:5432 + networks: + - magistrala-base-net + volumes: + - magistrala-alarms-db-volume:/var/lib/postgresql/data + + alarms: + image: ghcr.io/absmach/magistrala/alarms:${MG_RELEASE_TAG} + container_name: magistrala-alarms + depends_on: + - alarms-db + restart: on-failure + environment: + MG_ALARMS_LOG_LEVEL: ${MG_ALARMS_LOG_LEVEL} + MG_ALARMS_HTTP_PORT: ${MG_ALARMS_HTTP_PORT} + MG_ALARMS_HTTP_HOST: ${MG_ALARMS_HTTP_HOST} + MG_ALARMS_HTTP_SERVER_CERT: ${MG_ALARMS_HTTP_SERVER_CERT} + MG_ALARMS_HTTP_SERVER_KEY: ${MG_ALARMS_HTTP_SERVER_KEY} + MG_ALARMS_DB_HOST: ${MG_ALARMS_DB_HOST} + MG_ALARMS_DB_PORT: ${MG_ALARMS_DB_PORT} + MG_ALARMS_DB_USER: ${MG_ALARMS_DB_USER} + MG_ALARMS_DB_PASS: ${MG_ALARMS_DB_PASS} + MG_ALARMS_DB_NAME: ${MG_ALARMS_DB_NAME} + MG_ALARMS_DB_SSL_MODE: ${MG_ALARMS_DB_SSL_MODE} + MG_ALARMS_DB_SSL_CERT: ${MG_ALARMS_DB_SSL_CERT} + MG_ALARMS_DB_SSL_KEY: ${MG_ALARMS_DB_SSL_KEY} + MG_ALARMS_DB_SSL_ROOT_CERT: ${MG_ALARMS_DB_SSL_ROOT_CERT} + SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL} + SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} + SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} + SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} + SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} + SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} + SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} + SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_ALARMS_INSTANCE_ID: ${MG_ALARMS_INSTANCE_ID} + ports: + - ${MG_ALARMS_HTTP_PORT}:${MG_ALARMS_HTTP_PORT} + networks: + - magistrala-base-net + volumes: + # Auth gRPC client certificates + - type: bind + source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + target: /auth-grpc-server-ca${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true + - type: bind + source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + target: /domains-grpc-client${SMQ_DOMAINS_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + target: /domains-grpc-client${SMQ_DOMAINS_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + target: /domains-grpc-server-ca${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true diff --git a/go.mod b/go.mod index 3f6edb06d..0b6a9202e 100644 --- a/go.mod +++ b/go.mod @@ -5,17 +5,19 @@ go 1.24.2 require ( github.com/0x6flab/namegenerator v1.4.0 github.com/absmach/callhome v0.14.0 - github.com/absmach/supermq v0.16.1-0.20250411132829-0e571d1905af + github.com/absmach/supermq v0.16.1-0.20250411155830-602025b48bc0 github.com/authzed/authzed-go v1.3.1-0.20250320210445-0cde0d8c71e2 github.com/authzed/grpcutil v0.0.0-20250221190651-1985b19b35b8 github.com/caarlos0/env/v11 v11.3.1 github.com/eclipse/paho.mqtt.golang v1.5.0 + github.com/fatih/color v1.18.0 github.com/fiorix/go-smpp v0.0.0-20210403173735-2894b96e70ba github.com/go-chi/chi/v5 v5.2.1 github.com/go-kit/kit v0.13.0 github.com/gofrs/uuid/v5 v5.3.2 github.com/gookit/color v1.5.4 github.com/gorilla/websocket v1.5.3 + github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f github.com/ivanpirog/coloredcobra v1.0.1 github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 github.com/jackc/pgtype v1.14.4 @@ -41,21 +43,13 @@ require ( moul.io/http2curl v1.0.0 ) -require ( - filippo.io/edwards25519 v1.1.0 // indirect - github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02 // indirect - github.com/go-sql-driver/mysql v1.8.1 // indirect - github.com/go-viper/mapstructure/v2 v2.2.1 // indirect - github.com/lib/pq v1.10.9 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect - github.com/moby/sys/user v0.3.0 // indirect -) - require ( dario.cat/mergo v1.0.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect + github.com/absmach/certs v0.0.0-20250303232207-ef00d309ca02 // indirect github.com/absmach/senml v1.0.7 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect @@ -69,7 +63,6 @@ require ( github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect - github.com/fatih/color v1.18.0 github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/fxamacker/cbor/v2 v2.8.0 // indirect @@ -78,12 +71,13 @@ require ( github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect + github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect - github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -91,12 +85,15 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jzelinskie/stringz v0.0.3 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/moby/sys/user v0.3.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/nats-io/nats.go v1.41.0 // indirect + github.com/nats-io/nats.go v1.41.0 github.com/nats-io/nkeys v0.4.10 // indirect github.com/nats-io/nuid v1.0.1 // indirect github.com/oklog/ulid/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 994fd5e11..bd2476b71 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,8 @@ github.com/absmach/mgate v0.4.5 h1:l6RmrEsR9jxkdb9WHUSecmT0HA41TkZZQVffFfUAIfI= github.com/absmach/mgate v0.4.5/go.mod h1:IvRIHZexZPEIAPmmaJF0L5DY2ERjj+GxRGitOW4s6qo= github.com/absmach/senml v1.0.7 h1:XLvpw0qxbP2QhOz7KLM2ZRar+vSCpSG/0o0kEvWx3No= github.com/absmach/senml v1.0.7/go.mod h1:3bRIiNc8hq7l3auMs8gQrpsM5hHy7iDuiLILrf/+MfA= -github.com/absmach/supermq v0.16.1-0.20250411132829-0e571d1905af h1:FKVu3CX1V+8mob0hvfD0havIWXp9Wcv8nHJOeNhWmUE= -github.com/absmach/supermq v0.16.1-0.20250411132829-0e571d1905af/go.mod h1:dJqO3luvt+zLVuDhxOdsRRCuv945F86Nf1BXxFro+nU= +github.com/absmach/supermq v0.16.1-0.20250411155830-602025b48bc0 h1:vPTeIjQPCgDCMGSA9pICZNcTNpz8fRl5CBlpVewMCwY= +github.com/absmach/supermq v0.16.1-0.20250411155830-602025b48bc0/go.mod h1:dJqO3luvt+zLVuDhxOdsRRCuv945F86Nf1BXxFro+nU= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= diff --git a/pkg/sdk/bootstrap.go b/pkg/sdk/bootstrap.go index cc00a5920..3c7aa8317 100644 --- a/pkg/sdk/bootstrap.go +++ b/pkg/sdk/bootstrap.go @@ -4,6 +4,7 @@ package sdk import ( + "context" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -96,7 +97,7 @@ func (ts *BootstrapConfig) UnmarshalJSON(data []byte) error { return nil } -func (sdk mgSDK) AddBootstrap(cfg BootstrapConfig, domainID, token string) (string, errors.SDKError) { +func (sdk mgSDK) AddBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) (string, errors.SDKError) { data, err := json.Marshal(cfg) if err != nil { return "", errors.NewSDKError(err) @@ -104,7 +105,7 @@ func (sdk mgSDK) AddBootstrap(cfg BootstrapConfig, domainID, token string) (stri url := fmt.Sprintf("%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint) - headers, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusOK, http.StatusCreated) + headers, _, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, data, nil, http.StatusOK, http.StatusCreated) if sdkerr != nil { return "", sdkerr } @@ -114,14 +115,14 @@ func (sdk mgSDK) AddBootstrap(cfg BootstrapConfig, domainID, token string) (stri return id, nil } -func (sdk mgSDK) Bootstraps(pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) { +func (sdk mgSDK) Bootstraps(ctx context.Context, pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) { endpoint := fmt.Sprintf("%s/%s", domainID, configsEndpoint) url, err := sdk.withQueryParams(sdk.bootstrapURL, endpoint, pm) if err != nil { return BootstrapPage{}, errors.NewSDKError(err) } - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) if sdkerr != nil { return BootstrapPage{}, sdkerr } @@ -134,7 +135,7 @@ func (sdk mgSDK) Bootstraps(pm PageMetadata, domainID, token string) (BootstrapP return bb, nil } -func (sdk mgSDK) Whitelist(clientID string, state int, domainID, token string) errors.SDKError { +func (sdk mgSDK) Whitelist(ctx context.Context, clientID string, state int, domainID, token string) errors.SDKError { if clientID == "" { return errors.NewSDKError(apiutil.ErrMissingID) } @@ -146,18 +147,18 @@ func (sdk mgSDK) Whitelist(clientID string, state int, domainID, token string) e url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, whitelistEndpoint, clientID) - _, _, sdkerr := sdk.processRequest(http.MethodPut, url, token, data, nil, http.StatusCreated, http.StatusOK) + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusCreated, http.StatusOK) return sdkerr } -func (sdk mgSDK) ViewBootstrap(id, domainID, token string) (BootstrapConfig, errors.SDKError) { +func (sdk mgSDK) ViewBootstrap(ctx context.Context, id, domainID, token string) (BootstrapConfig, errors.SDKError) { if id == "" { return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) } url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint, id) - _, body, err := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) + _, body, err := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) if err != nil { return BootstrapConfig{}, err } @@ -170,7 +171,7 @@ func (sdk mgSDK) ViewBootstrap(id, domainID, token string) (BootstrapConfig, err return bc, nil } -func (sdk mgSDK) UpdateBootstrap(cfg BootstrapConfig, domainID, token string) errors.SDKError { +func (sdk mgSDK) UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) errors.SDKError { if cfg.ClientID == "" { return errors.NewSDKError(apiutil.ErrMissingID) } @@ -181,12 +182,12 @@ func (sdk mgSDK) UpdateBootstrap(cfg BootstrapConfig, domainID, token string) er return errors.NewSDKError(err) } - _, _, sdkerr := sdk.processRequest(http.MethodPut, url, token, data, nil, http.StatusOK) + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusOK) return sdkerr } -func (sdk mgSDK) UpdateBootstrapCerts(id, clientCert, clientKey, ca, domainID, token string) (BootstrapConfig, errors.SDKError) { +func (sdk mgSDK) UpdateBootstrapCerts(ctx context.Context, id, clientCert, clientKey, ca, domainID, token string) (BootstrapConfig, errors.SDKError) { if id == "" { return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) } @@ -202,7 +203,7 @@ func (sdk mgSDK) UpdateBootstrapCerts(id, clientCert, clientKey, ca, domainID, t return BootstrapConfig{}, errors.NewSDKError(err) } - _, body, sdkerr := sdk.processRequest(http.MethodPatch, url, token, data, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) if sdkerr != nil { return BootstrapConfig{}, sdkerr } @@ -215,7 +216,7 @@ func (sdk mgSDK) UpdateBootstrapCerts(id, clientCert, clientKey, ca, domainID, t return bc, nil } -func (sdk mgSDK) UpdateBootstrapConnection(id string, channels []string, domainID, token string) errors.SDKError { +func (sdk mgSDK) UpdateBootstrapConnection(ctx context.Context, id string, channels []string, domainID, token string) errors.SDKError { if id == "" { return errors.NewSDKError(apiutil.ErrMissingID) } @@ -228,27 +229,27 @@ func (sdk mgSDK) UpdateBootstrapConnection(id string, channels []string, domainI return errors.NewSDKError(err) } - _, _, sdkerr := sdk.processRequest(http.MethodPut, url, token, data, nil, http.StatusOK) + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusOK) return sdkerr } -func (sdk mgSDK) RemoveBootstrap(id, domainID, token string) errors.SDKError { +func (sdk mgSDK) RemoveBootstrap(ctx context.Context, id, domainID, token string) errors.SDKError { if id == "" { return errors.NewSDKError(apiutil.ErrMissingID) } url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint, id) - _, _, err := sdk.processRequest(http.MethodDelete, url, token, nil, nil, http.StatusNoContent) + _, _, err := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent) return err } -func (sdk mgSDK) Bootstrap(externalID, externalKey string) (BootstrapConfig, errors.SDKError) { +func (sdk mgSDK) Bootstrap(ctx context.Context, externalID, externalKey string) (BootstrapConfig, errors.SDKError) { if externalID == "" { return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) } url := fmt.Sprintf("%s/%s/%s", sdk.bootstrapURL, bootstrapEndpoint, externalID) - _, body, err := sdk.processRequest(http.MethodGet, url, smqSDK.ClientPrefix+externalKey, nil, nil, http.StatusOK) + _, body, err := sdk.processRequest(ctx, http.MethodGet, url, smqSDK.ClientPrefix+externalKey, nil, nil, http.StatusOK) if err != nil { return BootstrapConfig{}, err } @@ -261,7 +262,7 @@ func (sdk mgSDK) Bootstrap(externalID, externalKey string) (BootstrapConfig, err return bc, nil } -func (sdk mgSDK) BootstrapSecure(externalID, externalKey, cryptoKey string) (BootstrapConfig, errors.SDKError) { +func (sdk mgSDK) BootstrapSecure(ctx context.Context, externalID, externalKey, cryptoKey string) (BootstrapConfig, errors.SDKError) { if externalID == "" { return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) } @@ -272,7 +273,7 @@ func (sdk mgSDK) BootstrapSecure(externalID, externalKey, cryptoKey string) (Boo return BootstrapConfig{}, errors.NewSDKError(err) } - _, body, sdkErr := sdk.processRequest(http.MethodGet, url, smqSDK.ClientPrefix+encExtKey, nil, nil, http.StatusOK) + _, body, sdkErr := sdk.processRequest(ctx, http.MethodGet, url, smqSDK.ClientPrefix+encExtKey, nil, nil, http.StatusOK) if sdkErr != nil { return BootstrapConfig{}, sdkErr } diff --git a/pkg/sdk/bootstrap_test.go b/pkg/sdk/bootstrap_test.go index aafbc03ab..16a13094c 100644 --- a/pkg/sdk/bootstrap_test.go +++ b/pkg/sdk/bootstrap_test.go @@ -4,6 +4,7 @@ package sdk_test import ( + "context" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -252,7 +253,7 @@ func TestAddBootstrap(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("Add", mock.Anything, tc.session, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr) - resp, err := mgsdk.AddBootstrap(tc.cfg, tc.domainID, tc.token) + resp, err := mgsdk.AddBootstrap(context.Background(), tc.cfg, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if err == nil { assert.Equal(t, bootstrapConfig.ClientID, resp) @@ -399,7 +400,7 @@ func TestListBootstraps(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("List", mock.Anything, tc.session, mock.Anything, tc.pageMeta.Offset, tc.pageMeta.Limit).Return(tc.svcResp, tc.svcErr) - resp, err := mgsdk.Bootstraps(tc.pageMeta, tc.domainID, tc.token) + resp, err := mgsdk.Bootstraps(context.Background(), tc.pageMeta, tc.domainID, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if err == nil { @@ -504,7 +505,7 @@ func TestWhiteList(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("ChangeState", mock.Anything, tc.session, tc.token, tc.clientID, tc.svcReq).Return(tc.svcErr) - err := mgsdk.Whitelist(tc.clientID, tc.state, tc.domainID, tc.token) + err := mgsdk.Whitelist(context.Background(), tc.clientID, tc.state, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "ChangeState", mock.Anything, tc.session, tc.token, tc.clientID, tc.svcReq) @@ -625,7 +626,7 @@ func TestViewBootstrap(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("View", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) - resp, err := mgsdk.ViewBootstrap(tc.id, tc.domainID, tc.token) + resp, err := mgsdk.ViewBootstrap(context.Background(), tc.id, tc.domainID, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if err == nil { @@ -788,7 +789,7 @@ func TestUpdateBootstrap(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticationErr) svcCall := bsvc.On("Update", mock.Anything, tc.session, tc.svcReq).Return(tc.svcErr) - err := mgsdk.UpdateBootstrap(tc.cfg, tc.domainID, tc.token) + err := mgsdk.UpdateBootstrap(context.Background(), tc.cfg, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "Update", mock.Anything, tc.session, tc.svcReq) @@ -912,7 +913,7 @@ func TestUpdateBootstrapCerts(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("UpdateCert", mock.Anything, tc.session, tc.id, tc.clientCert, tc.clientKey, tc.caCert).Return(tc.svcResp, tc.svcErr) - resp, err := mgsdk.UpdateBootstrapCerts(tc.id, tc.clientCert, tc.clientKey, tc.caCert, tc.domainID, tc.token) + resp, err := mgsdk.UpdateBootstrapCerts(context.Background(), tc.id, tc.clientCert, tc.clientKey, tc.caCert, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if err == nil { assert.Equal(t, tc.response, resp) @@ -1015,7 +1016,7 @@ func TestUpdateBootstrapConnection(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("UpdateConnections", mock.Anything, tc.session, tc.token, tc.id, tc.channels).Return(tc.svcErr) - err := mgsdk.UpdateBootstrapConnection(tc.id, tc.channels, tc.domainID, tc.token) + err := mgsdk.UpdateBootstrapConnection(context.Background(), tc.id, tc.channels, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "UpdateConnections", mock.Anything, tc.session, tc.token, tc.id, tc.channels) @@ -1102,7 +1103,7 @@ func TestRemoveBootstrap(t *testing.T) { } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := bsvc.On("Remove", mock.Anything, tc.session, tc.id).Return(tc.svcErr) - err := mgsdk.RemoveBootstrap(tc.id, tc.domainID, tc.token) + err := mgsdk.RemoveBootstrap(context.Background(), tc.id, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "Remove", mock.Anything, tc.session, tc.id) @@ -1207,7 +1208,7 @@ func TestBoostrap(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { svcCall := bsvc.On("Bootstrap", mock.Anything, tc.externalKey, tc.externalID, false).Return(tc.svcResp, tc.svcErr) readerCall := reader.On("ReadConfig", tc.svcResp, false).Return(tc.readerResp, tc.readerErr) - resp, err := mgsdk.Bootstrap(tc.externalID, tc.externalKey) + resp, err := mgsdk.Bootstrap(context.Background(), tc.externalID, tc.externalKey) assert.Equal(t, tc.err, err) if err == nil { assert.Equal(t, tc.response, resp) @@ -1324,7 +1325,7 @@ func TestBootstrapSecure(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { svcCall := bsvc.On("Bootstrap", mock.Anything, mock.Anything, tc.externalID, true).Return(tc.svcResp, tc.svcErr) readerCall := reader.On("ReadConfig", tc.svcResp, true).Return(tc.readerResp, tc.readerErr) - resp, err := mgsdk.BootstrapSecure(tc.externalID, tc.externalKey, tc.cryptoKey) + resp, err := mgsdk.BootstrapSecure(context.Background(), tc.externalID, tc.externalKey, tc.cryptoKey) assert.Equal(t, tc.err, err) if err == nil { assert.Equal(t, sdkBootsrapConfigRes, resp) diff --git a/pkg/sdk/consumers.go b/pkg/sdk/consumers.go index 3019db67e..60c2dcf0d 100644 --- a/pkg/sdk/consumers.go +++ b/pkg/sdk/consumers.go @@ -4,6 +4,7 @@ package sdk import ( + "context" "encoding/json" "fmt" "net/http" @@ -21,7 +22,7 @@ type Subscription struct { Contact string `json:"contact,omitempty"` } -func (sdk mgSDK) CreateSubscription(topic, contact, token string) (string, errors.SDKError) { +func (sdk mgSDK) CreateSubscription(ctx context.Context, topic, contact, token string) (string, errors.SDKError) { sub := Subscription{ Topic: topic, Contact: contact, @@ -33,7 +34,7 @@ func (sdk mgSDK) CreateSubscription(topic, contact, token string) (string, error url := fmt.Sprintf("%s/%s", sdk.usersURL, subscriptionEndpoint) - headers, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusCreated) + headers, _, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, data, nil, http.StatusCreated) if sdkerr != nil { return "", sdkerr } @@ -43,13 +44,13 @@ func (sdk mgSDK) CreateSubscription(topic, contact, token string) (string, error return id, nil } -func (sdk mgSDK) ListSubscriptions(pm PageMetadata, token string) (SubscriptionPage, errors.SDKError) { +func (sdk mgSDK) ListSubscriptions(ctx context.Context, pm PageMetadata, token string) (SubscriptionPage, errors.SDKError) { url, err := sdk.withQueryParams(sdk.usersURL, subscriptionEndpoint, pm) if err != nil { return SubscriptionPage{}, errors.NewSDKError(err) } - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) if sdkerr != nil { return SubscriptionPage{}, sdkerr } @@ -62,10 +63,10 @@ func (sdk mgSDK) ListSubscriptions(pm PageMetadata, token string) (SubscriptionP return sp, nil } -func (sdk mgSDK) ViewSubscription(id, token string) (Subscription, errors.SDKError) { +func (sdk mgSDK) ViewSubscription(ctx context.Context, id, token string) (Subscription, errors.SDKError) { url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, subscriptionEndpoint, id) - _, body, err := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) + _, body, err := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) if err != nil { return Subscription{}, err } @@ -78,10 +79,10 @@ func (sdk mgSDK) ViewSubscription(id, token string) (Subscription, errors.SDKErr return sub, nil } -func (sdk mgSDK) DeleteSubscription(id, token string) errors.SDKError { +func (sdk mgSDK) DeleteSubscription(ctx context.Context, id, token string) errors.SDKError { url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, subscriptionEndpoint, id) - _, _, err := sdk.processRequest(http.MethodDelete, url, token, nil, nil, http.StatusNoContent) + _, _, err := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent) return err } diff --git a/pkg/sdk/consumers_test.go b/pkg/sdk/consumers_test.go index 6fc997719..bd07c4af4 100644 --- a/pkg/sdk/consumers_test.go +++ b/pkg/sdk/consumers_test.go @@ -4,6 +4,7 @@ package sdk_test import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -142,7 +143,7 @@ func TestCreateSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svcCall := nsvc.On("CreateSubscription", mock.Anything, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr) - loc, err := mgsdk.CreateSubscription(tc.subscription.Topic, tc.subscription.Contact, tc.token) + loc, err := mgsdk.CreateSubscription(context.Background(), tc.subscription.Topic, tc.subscription.Contact, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.empty, loc == "") if tc.err == nil { @@ -214,7 +215,7 @@ func TestViewSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svcCall := nsvc.On("ViewSubscription", mock.Anything, tc.token, tc.subID).Return(tc.svcRes, tc.svcErr) - resp, err := mgsdk.ViewSubscription(tc.subID, tc.token) + resp, err := mgsdk.ViewSubscription(context.Background(), tc.subID, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if tc.err == nil { @@ -374,7 +375,7 @@ func TestListSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svcCall := nsvc.On("ListSubscriptions", mock.Anything, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr) - resp, err := mgsdk.ListSubscriptions(tc.pageMeta, tc.token) + resp, err := mgsdk.ListSubscriptions(context.Background(), tc.pageMeta, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if tc.err == nil { @@ -444,7 +445,7 @@ func TestDeleteSubscription(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { svcCall := nsvc.On("RemoveSubscription", mock.Anything, tc.token, tc.subID).Return(tc.svcErr) - err := mgsdk.DeleteSubscription(tc.subID, tc.token) + err := mgsdk.DeleteSubscription(context.Background(), tc.subID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "RemoveSubscription", mock.Anything, tc.token, tc.subID) diff --git a/pkg/sdk/messages.go b/pkg/sdk/messages.go index 3dec4cf50..b6d0e9931 100644 --- a/pkg/sdk/messages.go +++ b/pkg/sdk/messages.go @@ -4,6 +4,7 @@ package sdk import ( + "context" "encoding/json" "fmt" "net/http" @@ -16,7 +17,7 @@ import ( const channelParts = 2 -func (sdk mgSDK) ReadMessages(pm MessagePageMetadata, chanName, domainID, token string) (MessagesPage, errors.SDKError) { +func (sdk mgSDK) ReadMessages(ctx context.Context, pm MessagePageMetadata, chanName, domainID, token string) (MessagesPage, errors.SDKError) { chanNameParts := strings.SplitN(chanName, ".", channelParts) chanID := chanNameParts[0] subtopicPart := "" @@ -32,7 +33,7 @@ func (sdk mgSDK) ReadMessages(pm MessagePageMetadata, chanName, domainID, token header := make(map[string]string) header["Content-Type"] = string(sdk.msgContentType) - _, body, sdkerr := sdk.processRequest(http.MethodGet, msgURL, token, nil, header, http.StatusOK) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, msgURL, token, nil, header, http.StatusOK) if sdkerr != nil { return MessagesPage{}, sdkerr } diff --git a/pkg/sdk/messages_test.go b/pkg/sdk/messages_test.go index c8fce8af4..f52d01ccd 100644 --- a/pkg/sdk/messages_test.go +++ b/pkg/sdk/messages_test.go @@ -4,6 +4,7 @@ package sdk_test import ( + "context" "net/http" "net/http/httptest" "testing" @@ -229,7 +230,7 @@ func TestReadMessages(t *testing.T) { authCall1 := authn.On("Authenticate", mock.Anything, tc.token).Return(smqauthn.Session{UserID: validID}, tc.authnErr) authzCall := channelsGRPCClient.On("Authorize", mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, tc.authzErr) repoCall := repo.On("ReadAll", channelID, mock.Anything).Return(tc.repoRes, tc.repoErr) - response, err := mgsdk.ReadMessages(tc.messagePageMeta, tc.chanName, tc.domainID, tc.token) + response, err := mgsdk.ReadMessages(context.Background(), tc.messagePageMeta, tc.chanName, tc.domainID, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, response) if tc.err == nil { diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index d67ff4203..91eb56458 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -92,8 +92,8 @@ func (_c *SDK_AcceptInvitation_Call) RunAndReturn(run func(ctx context.Context, } // AddBootstrap provides a mock function for the type SDK -func (_mock *SDK) AddBootstrap(cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError) { - ret := _mock.Called(cfg, domainID, token) +func (_mock *SDK) AddBootstrap(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError) { + ret := _mock.Called(ctx, cfg, domainID, token) if len(ret) == 0 { panic("no return value specified for AddBootstrap") @@ -101,16 +101,16 @@ func (_mock *SDK) AddBootstrap(cfg sdk.BootstrapConfig, domainID string, token s var r0 string var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.BootstrapConfig, string, string) (string, errors.SDKError)); ok { - return returnFunc(cfg, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapConfig, string, string) (string, errors.SDKError)); ok { + return returnFunc(ctx, cfg, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(sdk.BootstrapConfig, string, string) string); ok { - r0 = returnFunc(cfg, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapConfig, string, string) string); ok { + r0 = returnFunc(ctx, cfg, domainID, token) } else { r0 = ret.Get(0).(string) } - if returnFunc, ok := ret.Get(1).(func(sdk.BootstrapConfig, string, string) errors.SDKError); ok { - r1 = returnFunc(cfg, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.BootstrapConfig, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, cfg, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -125,16 +125,17 @@ type SDK_AddBootstrap_Call struct { } // AddBootstrap is a helper method to define mock.On call +// - ctx // - cfg // - domainID // - token -func (_e *SDK_Expecter) AddBootstrap(cfg interface{}, domainID interface{}, token interface{}) *SDK_AddBootstrap_Call { - return &SDK_AddBootstrap_Call{Call: _e.mock.On("AddBootstrap", cfg, domainID, token)} +func (_e *SDK_Expecter) AddBootstrap(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_AddBootstrap_Call { + return &SDK_AddBootstrap_Call{Call: _e.mock.On("AddBootstrap", ctx, cfg, domainID, token)} } -func (_c *SDK_AddBootstrap_Call) Run(run func(cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_AddBootstrap_Call { +func (_c *SDK_AddBootstrap_Call) Run(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_AddBootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.BootstrapConfig), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(sdk.BootstrapConfig), args[2].(string), args[3].(string)) }) return _c } @@ -144,7 +145,7 @@ func (_c *SDK_AddBootstrap_Call) Return(s string, sDKError errors.SDKError) *SDK return _c } -func (_c *SDK_AddBootstrap_Call) RunAndReturn(run func(cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError)) *SDK_AddBootstrap_Call { +func (_c *SDK_AddBootstrap_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError)) *SDK_AddBootstrap_Call { _c.Call.Return(run) return _c } @@ -756,8 +757,8 @@ func (_c *SDK_AvailableGroupRoleActions_Call) RunAndReturn(run func(ctx context. } // Bootstrap provides a mock function for the type SDK -func (_mock *SDK) Bootstrap(externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError) { - ret := _mock.Called(externalID, externalKey) +func (_mock *SDK) Bootstrap(ctx context.Context, externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, externalID, externalKey) if len(ret) == 0 { panic("no return value specified for Bootstrap") @@ -765,16 +766,16 @@ func (_mock *SDK) Bootstrap(externalID string, externalKey string) (sdk.Bootstra var r0 sdk.BootstrapConfig var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { - return returnFunc(externalID, externalKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, externalID, externalKey) } - if returnFunc, ok := ret.Get(0).(func(string, string) sdk.BootstrapConfig); ok { - r0 = returnFunc(externalID, externalKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, externalID, externalKey) } else { r0 = ret.Get(0).(sdk.BootstrapConfig) } - if returnFunc, ok := ret.Get(1).(func(string, string) errors.SDKError); ok { - r1 = returnFunc(externalID, externalKey) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, externalID, externalKey) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -789,15 +790,16 @@ type SDK_Bootstrap_Call struct { } // Bootstrap is a helper method to define mock.On call +// - ctx // - externalID // - externalKey -func (_e *SDK_Expecter) Bootstrap(externalID interface{}, externalKey interface{}) *SDK_Bootstrap_Call { - return &SDK_Bootstrap_Call{Call: _e.mock.On("Bootstrap", externalID, externalKey)} +func (_e *SDK_Expecter) Bootstrap(ctx interface{}, externalID interface{}, externalKey interface{}) *SDK_Bootstrap_Call { + return &SDK_Bootstrap_Call{Call: _e.mock.On("Bootstrap", ctx, externalID, externalKey)} } -func (_c *SDK_Bootstrap_Call) Run(run func(externalID string, externalKey string)) *SDK_Bootstrap_Call { +func (_c *SDK_Bootstrap_Call) Run(run func(ctx context.Context, externalID string, externalKey string)) *SDK_Bootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string)) }) return _c } @@ -807,14 +809,14 @@ func (_c *SDK_Bootstrap_Call) Return(bootstrapConfig sdk.BootstrapConfig, sDKErr return _c } -func (_c *SDK_Bootstrap_Call) RunAndReturn(run func(externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_Bootstrap_Call { +func (_c *SDK_Bootstrap_Call) RunAndReturn(run func(ctx context.Context, externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_Bootstrap_Call { _c.Call.Return(run) return _c } // BootstrapSecure provides a mock function for the type SDK -func (_mock *SDK) BootstrapSecure(externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError) { - ret := _mock.Called(externalID, externalKey, cryptoKey) +func (_mock *SDK) BootstrapSecure(ctx context.Context, externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, externalID, externalKey, cryptoKey) if len(ret) == 0 { panic("no return value specified for BootstrapSecure") @@ -822,16 +824,16 @@ func (_mock *SDK) BootstrapSecure(externalID string, externalKey string, cryptoK var r0 sdk.BootstrapConfig var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { - return returnFunc(externalID, externalKey, cryptoKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, externalID, externalKey, cryptoKey) } - if returnFunc, ok := ret.Get(0).(func(string, string, string) sdk.BootstrapConfig); ok { - r0 = returnFunc(externalID, externalKey, cryptoKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, externalID, externalKey, cryptoKey) } else { r0 = ret.Get(0).(sdk.BootstrapConfig) } - if returnFunc, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { - r1 = returnFunc(externalID, externalKey, cryptoKey) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, externalID, externalKey, cryptoKey) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -846,16 +848,17 @@ type SDK_BootstrapSecure_Call struct { } // BootstrapSecure is a helper method to define mock.On call +// - ctx // - externalID // - externalKey // - cryptoKey -func (_e *SDK_Expecter) BootstrapSecure(externalID interface{}, externalKey interface{}, cryptoKey interface{}) *SDK_BootstrapSecure_Call { - return &SDK_BootstrapSecure_Call{Call: _e.mock.On("BootstrapSecure", externalID, externalKey, cryptoKey)} +func (_e *SDK_Expecter) BootstrapSecure(ctx interface{}, externalID interface{}, externalKey interface{}, cryptoKey interface{}) *SDK_BootstrapSecure_Call { + return &SDK_BootstrapSecure_Call{Call: _e.mock.On("BootstrapSecure", ctx, externalID, externalKey, cryptoKey)} } -func (_c *SDK_BootstrapSecure_Call) Run(run func(externalID string, externalKey string, cryptoKey string)) *SDK_BootstrapSecure_Call { +func (_c *SDK_BootstrapSecure_Call) Run(run func(ctx context.Context, externalID string, externalKey string, cryptoKey string)) *SDK_BootstrapSecure_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -865,14 +868,14 @@ func (_c *SDK_BootstrapSecure_Call) Return(bootstrapConfig sdk.BootstrapConfig, return _c } -func (_c *SDK_BootstrapSecure_Call) RunAndReturn(run func(externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_BootstrapSecure_Call { +func (_c *SDK_BootstrapSecure_Call) RunAndReturn(run func(ctx context.Context, externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_BootstrapSecure_Call { _c.Call.Return(run) return _c } // Bootstraps provides a mock function for the type SDK -func (_mock *SDK) Bootstraps(pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError) { - ret := _mock.Called(pm, domainID, token) +func (_mock *SDK) Bootstraps(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, domainID, token) if len(ret) == 0 { panic("no return value specified for Bootstraps") @@ -880,16 +883,16 @@ func (_mock *SDK) Bootstraps(pm sdk.PageMetadata, domainID string, token string) var r0 sdk.BootstrapPage var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) (sdk.BootstrapPage, errors.SDKError)); ok { - return returnFunc(pm, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) (sdk.BootstrapPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(sdk.PageMetadata, string, string) sdk.BootstrapPage); ok { - r0 = returnFunc(pm, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) sdk.BootstrapPage); ok { + r0 = returnFunc(ctx, pm, domainID, token) } else { r0 = ret.Get(0).(sdk.BootstrapPage) } - if returnFunc, ok := ret.Get(1).(func(sdk.PageMetadata, string, string) errors.SDKError); ok { - r1 = returnFunc(pm, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -904,16 +907,17 @@ type SDK_Bootstraps_Call struct { } // Bootstraps is a helper method to define mock.On call +// - ctx // - pm // - domainID // - token -func (_e *SDK_Expecter) Bootstraps(pm interface{}, domainID interface{}, token interface{}) *SDK_Bootstraps_Call { - return &SDK_Bootstraps_Call{Call: _e.mock.On("Bootstraps", pm, domainID, token)} +func (_e *SDK_Expecter) Bootstraps(ctx interface{}, pm interface{}, domainID interface{}, token interface{}) *SDK_Bootstraps_Call { + return &SDK_Bootstraps_Call{Call: _e.mock.On("Bootstraps", ctx, pm, domainID, token)} } -func (_c *SDK_Bootstraps_Call) Run(run func(pm sdk.PageMetadata, domainID string, token string)) *SDK_Bootstraps_Call { +func (_c *SDK_Bootstraps_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string)) *SDK_Bootstraps_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.PageMetadata), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(sdk.PageMetadata), args[2].(string), args[3].(string)) }) return _c } @@ -923,7 +927,7 @@ func (_c *SDK_Bootstraps_Call) Return(bootstrapPage sdk.BootstrapPage, sDKError return _c } -func (_c *SDK_Bootstraps_Call) RunAndReturn(run func(pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError)) *SDK_Bootstraps_Call { +func (_c *SDK_Bootstraps_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError)) *SDK_Bootstraps_Call { _c.Call.Return(run) return _c } @@ -2106,8 +2110,8 @@ func (_c *SDK_CreateGroupRole_Call) RunAndReturn(run func(ctx context.Context, i } // CreateSubscription provides a mock function for the type SDK -func (_mock *SDK) CreateSubscription(topic string, contact string, token string) (string, errors.SDKError) { - ret := _mock.Called(topic, contact, token) +func (_mock *SDK) CreateSubscription(ctx context.Context, topic string, contact string, token string) (string, errors.SDKError) { + ret := _mock.Called(ctx, topic, contact, token) if len(ret) == 0 { panic("no return value specified for CreateSubscription") @@ -2115,16 +2119,16 @@ func (_mock *SDK) CreateSubscription(topic string, contact string, token string) var r0 string var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string) (string, errors.SDKError)); ok { - return returnFunc(topic, contact, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (string, errors.SDKError)); ok { + return returnFunc(ctx, topic, contact, token) } - if returnFunc, ok := ret.Get(0).(func(string, string, string) string); ok { - r0 = returnFunc(topic, contact, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) string); ok { + r0 = returnFunc(ctx, topic, contact, token) } else { r0 = ret.Get(0).(string) } - if returnFunc, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { - r1 = returnFunc(topic, contact, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, topic, contact, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -2139,16 +2143,17 @@ type SDK_CreateSubscription_Call struct { } // CreateSubscription is a helper method to define mock.On call +// - ctx // - topic // - contact // - token -func (_e *SDK_Expecter) CreateSubscription(topic interface{}, contact interface{}, token interface{}) *SDK_CreateSubscription_Call { - return &SDK_CreateSubscription_Call{Call: _e.mock.On("CreateSubscription", topic, contact, token)} +func (_e *SDK_Expecter) CreateSubscription(ctx interface{}, topic interface{}, contact interface{}, token interface{}) *SDK_CreateSubscription_Call { + return &SDK_CreateSubscription_Call{Call: _e.mock.On("CreateSubscription", ctx, topic, contact, token)} } -func (_c *SDK_CreateSubscription_Call) Run(run func(topic string, contact string, token string)) *SDK_CreateSubscription_Call { +func (_c *SDK_CreateSubscription_Call) Run(run func(ctx context.Context, topic string, contact string, token string)) *SDK_CreateSubscription_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -2158,7 +2163,7 @@ func (_c *SDK_CreateSubscription_Call) Return(s string, sDKError errors.SDKError return _c } -func (_c *SDK_CreateSubscription_Call) RunAndReturn(run func(topic string, contact string, token string) (string, errors.SDKError)) *SDK_CreateSubscription_Call { +func (_c *SDK_CreateSubscription_Call) RunAndReturn(run func(ctx context.Context, topic string, contact string, token string) (string, errors.SDKError)) *SDK_CreateSubscription_Call { _c.Call.Return(run) return _c } @@ -2629,16 +2634,16 @@ func (_c *SDK_DeleteInvitation_Call) RunAndReturn(run func(ctx context.Context, } // DeleteSubscription provides a mock function for the type SDK -func (_mock *SDK) DeleteSubscription(id string, token string) errors.SDKError { - ret := _mock.Called(id, token) +func (_mock *SDK) DeleteSubscription(ctx context.Context, id string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, token) if len(ret) == 0 { panic("no return value specified for DeleteSubscription") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string) errors.SDKError); ok { - r0 = returnFunc(id, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -2653,15 +2658,16 @@ type SDK_DeleteSubscription_Call struct { } // DeleteSubscription is a helper method to define mock.On call +// - ctx // - id // - token -func (_e *SDK_Expecter) DeleteSubscription(id interface{}, token interface{}) *SDK_DeleteSubscription_Call { - return &SDK_DeleteSubscription_Call{Call: _e.mock.On("DeleteSubscription", id, token)} +func (_e *SDK_Expecter) DeleteSubscription(ctx interface{}, id interface{}, token interface{}) *SDK_DeleteSubscription_Call { + return &SDK_DeleteSubscription_Call{Call: _e.mock.On("DeleteSubscription", ctx, id, token)} } -func (_c *SDK_DeleteSubscription_Call) Run(run func(id string, token string)) *SDK_DeleteSubscription_Call { +func (_c *SDK_DeleteSubscription_Call) Run(run func(ctx context.Context, id string, token string)) *SDK_DeleteSubscription_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string)) }) return _c } @@ -2671,7 +2677,7 @@ func (_c *SDK_DeleteSubscription_Call) Return(sDKError errors.SDKError) *SDK_Del return _c } -func (_c *SDK_DeleteSubscription_Call) RunAndReturn(run func(id string, token string) errors.SDKError) *SDK_DeleteSubscription_Call { +func (_c *SDK_DeleteSubscription_Call) RunAndReturn(run func(ctx context.Context, id string, token string) errors.SDKError) *SDK_DeleteSubscription_Call { _c.Call.Return(run) return _c } @@ -4748,8 +4754,8 @@ func (_c *SDK_ListGroupMembers_Call) RunAndReturn(run func(ctx context.Context, } // ListSubscriptions provides a mock function for the type SDK -func (_mock *SDK) ListSubscriptions(pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError) { - ret := _mock.Called(pm, token) +func (_mock *SDK) ListSubscriptions(ctx context.Context, pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, token) if len(ret) == 0 { panic("no return value specified for ListSubscriptions") @@ -4757,16 +4763,16 @@ func (_mock *SDK) ListSubscriptions(pm sdk.PageMetadata, token string) (sdk.Subs var r0 sdk.SubscriptionPage var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.SubscriptionPage, errors.SDKError)); ok { - return returnFunc(pm, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string) (sdk.SubscriptionPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, token) } - if returnFunc, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.SubscriptionPage); ok { - r0 = returnFunc(pm, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string) sdk.SubscriptionPage); ok { + r0 = returnFunc(ctx, pm, token) } else { r0 = ret.Get(0).(sdk.SubscriptionPage) } - if returnFunc, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok { - r1 = returnFunc(pm, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -4781,15 +4787,16 @@ type SDK_ListSubscriptions_Call struct { } // ListSubscriptions is a helper method to define mock.On call +// - ctx // - pm // - token -func (_e *SDK_Expecter) ListSubscriptions(pm interface{}, token interface{}) *SDK_ListSubscriptions_Call { - return &SDK_ListSubscriptions_Call{Call: _e.mock.On("ListSubscriptions", pm, token)} +func (_e *SDK_Expecter) ListSubscriptions(ctx interface{}, pm interface{}, token interface{}) *SDK_ListSubscriptions_Call { + return &SDK_ListSubscriptions_Call{Call: _e.mock.On("ListSubscriptions", ctx, pm, token)} } -func (_c *SDK_ListSubscriptions_Call) Run(run func(pm sdk.PageMetadata, token string)) *SDK_ListSubscriptions_Call { +func (_c *SDK_ListSubscriptions_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, token string)) *SDK_ListSubscriptions_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.PageMetadata), args[1].(string)) + run(args[0].(context.Context), args[1].(sdk.PageMetadata), args[2].(string)) }) return _c } @@ -4799,14 +4806,14 @@ func (_c *SDK_ListSubscriptions_Call) Return(subscriptionPage sdk.SubscriptionPa return _c } -func (_c *SDK_ListSubscriptions_Call) RunAndReturn(run func(pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError)) *SDK_ListSubscriptions_Call { +func (_c *SDK_ListSubscriptions_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError)) *SDK_ListSubscriptions_Call { _c.Call.Return(run) return _c } // ReadMessages provides a mock function for the type SDK -func (_mock *SDK) ReadMessages(pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError) { - ret := _mock.Called(pm, chanID, domainID, token) +func (_mock *SDK) ReadMessages(ctx context.Context, pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, chanID, domainID, token) if len(ret) == 0 { panic("no return value specified for ReadMessages") @@ -4814,16 +4821,16 @@ func (_mock *SDK) ReadMessages(pm sdk.MessagePageMetadata, chanID string, domain var r0 sdk.MessagesPage var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.MessagePageMetadata, string, string, string) (sdk.MessagesPage, errors.SDKError)); ok { - return returnFunc(pm, chanID, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.MessagePageMetadata, string, string, string) (sdk.MessagesPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, chanID, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(sdk.MessagePageMetadata, string, string, string) sdk.MessagesPage); ok { - r0 = returnFunc(pm, chanID, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.MessagePageMetadata, string, string, string) sdk.MessagesPage); ok { + r0 = returnFunc(ctx, pm, chanID, domainID, token) } else { r0 = ret.Get(0).(sdk.MessagesPage) } - if returnFunc, ok := ret.Get(1).(func(sdk.MessagePageMetadata, string, string, string) errors.SDKError); ok { - r1 = returnFunc(pm, chanID, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.MessagePageMetadata, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, chanID, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -4838,17 +4845,18 @@ type SDK_ReadMessages_Call struct { } // ReadMessages is a helper method to define mock.On call +// - ctx // - pm // - chanID // - domainID // - token -func (_e *SDK_Expecter) ReadMessages(pm interface{}, chanID interface{}, domainID interface{}, token interface{}) *SDK_ReadMessages_Call { - return &SDK_ReadMessages_Call{Call: _e.mock.On("ReadMessages", pm, chanID, domainID, token)} +func (_e *SDK_Expecter) ReadMessages(ctx interface{}, pm interface{}, chanID interface{}, domainID interface{}, token interface{}) *SDK_ReadMessages_Call { + return &SDK_ReadMessages_Call{Call: _e.mock.On("ReadMessages", ctx, pm, chanID, domainID, token)} } -func (_c *SDK_ReadMessages_Call) Run(run func(pm sdk.MessagePageMetadata, chanID string, domainID string, token string)) *SDK_ReadMessages_Call { +func (_c *SDK_ReadMessages_Call) Run(run func(ctx context.Context, pm sdk.MessagePageMetadata, chanID string, domainID string, token string)) *SDK_ReadMessages_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.MessagePageMetadata), args[1].(string), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(sdk.MessagePageMetadata), args[2].(string), args[3].(string), args[4].(string)) }) return _c } @@ -4858,7 +4866,7 @@ func (_c *SDK_ReadMessages_Call) Return(messagesPage sdk.MessagesPage, sDKError return _c } -func (_c *SDK_ReadMessages_Call) RunAndReturn(run func(pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError)) *SDK_ReadMessages_Call { +func (_c *SDK_ReadMessages_Call) RunAndReturn(run func(ctx context.Context, pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError)) *SDK_ReadMessages_Call { _c.Call.Return(run) return _c } @@ -5322,16 +5330,16 @@ func (_c *SDK_RemoveAllGroupRoleMembers_Call) RunAndReturn(run func(ctx context. } // RemoveBootstrap provides a mock function for the type SDK -func (_mock *SDK) RemoveBootstrap(id string, domainID string, token string) errors.SDKError { - ret := _mock.Called(id, domainID, token) +func (_mock *SDK) RemoveBootstrap(ctx context.Context, id string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, domainID, token) if len(ret) == 0 { panic("no return value specified for RemoveBootstrap") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string) errors.SDKError); ok { - r0 = returnFunc(id, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, domainID, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -5346,16 +5354,17 @@ type SDK_RemoveBootstrap_Call struct { } // RemoveBootstrap is a helper method to define mock.On call +// - ctx // - id // - domainID // - token -func (_e *SDK_Expecter) RemoveBootstrap(id interface{}, domainID interface{}, token interface{}) *SDK_RemoveBootstrap_Call { - return &SDK_RemoveBootstrap_Call{Call: _e.mock.On("RemoveBootstrap", id, domainID, token)} +func (_e *SDK_Expecter) RemoveBootstrap(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_RemoveBootstrap_Call { + return &SDK_RemoveBootstrap_Call{Call: _e.mock.On("RemoveBootstrap", ctx, id, domainID, token)} } -func (_c *SDK_RemoveBootstrap_Call) Run(run func(id string, domainID string, token string)) *SDK_RemoveBootstrap_Call { +func (_c *SDK_RemoveBootstrap_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_RemoveBootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -5365,7 +5374,7 @@ func (_c *SDK_RemoveBootstrap_Call) Return(sDKError errors.SDKError) *SDK_Remove return _c } -func (_c *SDK_RemoveBootstrap_Call) RunAndReturn(run func(id string, domainID string, token string) errors.SDKError) *SDK_RemoveBootstrap_Call { +func (_c *SDK_RemoveBootstrap_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) errors.SDKError) *SDK_RemoveBootstrap_Call { _c.Call.Return(run) return _c } @@ -6147,8 +6156,8 @@ func (_c *SDK_SendInvitation_Call) RunAndReturn(run func(ctx context.Context, in } // SendMessage provides a mock function for the type SDK -func (_mock *SDK) SendMessage(ctx context.Context, domainID string, topic string, secret string, msg string) errors.SDKError { - ret := _mock.Called(ctx, domainID, topic, secret, msg) +func (_mock *SDK) SendMessage(ctx context.Context, domainID string, topic string, msg string, secret string) errors.SDKError { + ret := _mock.Called(ctx, domainID, topic, msg, secret) if len(ret) == 0 { panic("no return value specified for SendMessage") @@ -6156,7 +6165,7 @@ func (_mock *SDK) SendMessage(ctx context.Context, domainID string, topic string var r0 errors.SDKError if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) errors.SDKError); ok { - r0 = returnFunc(ctx, domainID, topic, secret, msg) + r0 = returnFunc(ctx, domainID, topic, msg, secret) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -6174,13 +6183,13 @@ type SDK_SendMessage_Call struct { // - ctx // - domainID // - topic -// - secret // - msg -func (_e *SDK_Expecter) SendMessage(ctx interface{}, domainID interface{}, topic interface{}, secret interface{}, msg interface{}) *SDK_SendMessage_Call { - return &SDK_SendMessage_Call{Call: _e.mock.On("SendMessage", ctx, domainID, topic, secret, msg)} +// - secret +func (_e *SDK_Expecter) SendMessage(ctx interface{}, domainID interface{}, topic interface{}, msg interface{}, secret interface{}) *SDK_SendMessage_Call { + return &SDK_SendMessage_Call{Call: _e.mock.On("SendMessage", ctx, domainID, topic, msg, secret)} } -func (_c *SDK_SendMessage_Call) Run(run func(ctx context.Context, domainID string, topic string, secret string, msg string)) *SDK_SendMessage_Call { +func (_c *SDK_SendMessage_Call) Run(run func(ctx context.Context, domainID string, topic string, msg string, secret string)) *SDK_SendMessage_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) }) @@ -6192,7 +6201,7 @@ func (_c *SDK_SendMessage_Call) Return(sDKError errors.SDKError) *SDK_SendMessag return _c } -func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, domainID string, topic string, secret string, msg string) errors.SDKError) *SDK_SendMessage_Call { +func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, domainID string, topic string, msg string, secret string) errors.SDKError) *SDK_SendMessage_Call { _c.Call.Return(run) return _c } @@ -6398,16 +6407,16 @@ func (_c *SDK_SetGroupParent_Call) RunAndReturn(run func(ctx context.Context, id } // UpdateBootstrap provides a mock function for the type SDK -func (_mock *SDK) UpdateBootstrap(cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError { - ret := _mock.Called(cfg, domainID, token) +func (_mock *SDK) UpdateBootstrap(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, cfg, domainID, token) if len(ret) == 0 { panic("no return value specified for UpdateBootstrap") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(sdk.BootstrapConfig, string, string) errors.SDKError); ok { - r0 = returnFunc(cfg, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapConfig, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, cfg, domainID, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -6422,16 +6431,17 @@ type SDK_UpdateBootstrap_Call struct { } // UpdateBootstrap is a helper method to define mock.On call +// - ctx // - cfg // - domainID // - token -func (_e *SDK_Expecter) UpdateBootstrap(cfg interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrap_Call { - return &SDK_UpdateBootstrap_Call{Call: _e.mock.On("UpdateBootstrap", cfg, domainID, token)} +func (_e *SDK_Expecter) UpdateBootstrap(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrap_Call { + return &SDK_UpdateBootstrap_Call{Call: _e.mock.On("UpdateBootstrap", ctx, cfg, domainID, token)} } -func (_c *SDK_UpdateBootstrap_Call) Run(run func(cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_UpdateBootstrap_Call { +func (_c *SDK_UpdateBootstrap_Call) Run(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_UpdateBootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.BootstrapConfig), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(sdk.BootstrapConfig), args[2].(string), args[3].(string)) }) return _c } @@ -6441,14 +6451,14 @@ func (_c *SDK_UpdateBootstrap_Call) Return(sDKError errors.SDKError) *SDK_Update return _c } -func (_c *SDK_UpdateBootstrap_Call) RunAndReturn(run func(cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrap_Call { +func (_c *SDK_UpdateBootstrap_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrap_Call { _c.Call.Return(run) return _c } // UpdateBootstrapCerts provides a mock function for the type SDK -func (_mock *SDK) UpdateBootstrapCerts(id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { - ret := _mock.Called(id, clientCert, clientKey, ca, domainID, token) +func (_mock *SDK) UpdateBootstrapCerts(ctx context.Context, id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, id, clientCert, clientKey, ca, domainID, token) if len(ret) == 0 { panic("no return value specified for UpdateBootstrapCerts") @@ -6456,16 +6466,16 @@ func (_mock *SDK) UpdateBootstrapCerts(id string, clientCert string, clientKey s var r0 sdk.BootstrapConfig var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { - return returnFunc(id, clientCert, clientKey, ca, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, id, clientCert, clientKey, ca, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(string, string, string, string, string, string) sdk.BootstrapConfig); ok { - r0 = returnFunc(id, clientCert, clientKey, ca, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, id, clientCert, clientKey, ca, domainID, token) } else { r0 = ret.Get(0).(sdk.BootstrapConfig) } - if returnFunc, ok := ret.Get(1).(func(string, string, string, string, string, string) errors.SDKError); ok { - r1 = returnFunc(id, clientCert, clientKey, ca, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, clientCert, clientKey, ca, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -6480,19 +6490,20 @@ type SDK_UpdateBootstrapCerts_Call struct { } // UpdateBootstrapCerts is a helper method to define mock.On call +// - ctx // - id // - clientCert // - clientKey // - ca // - domainID // - token -func (_e *SDK_Expecter) UpdateBootstrapCerts(id interface{}, clientCert interface{}, clientKey interface{}, ca interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapCerts_Call { - return &SDK_UpdateBootstrapCerts_Call{Call: _e.mock.On("UpdateBootstrapCerts", id, clientCert, clientKey, ca, domainID, token)} +func (_e *SDK_Expecter) UpdateBootstrapCerts(ctx interface{}, id interface{}, clientCert interface{}, clientKey interface{}, ca interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapCerts_Call { + return &SDK_UpdateBootstrapCerts_Call{Call: _e.mock.On("UpdateBootstrapCerts", ctx, id, clientCert, clientKey, ca, domainID, token)} } -func (_c *SDK_UpdateBootstrapCerts_Call) Run(run func(id string, clientCert string, clientKey string, ca string, domainID string, token string)) *SDK_UpdateBootstrapCerts_Call { +func (_c *SDK_UpdateBootstrapCerts_Call) Run(run func(ctx context.Context, id string, clientCert string, clientKey string, ca string, domainID string, token string)) *SDK_UpdateBootstrapCerts_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(string), args[6].(string)) }) return _c } @@ -6502,22 +6513,22 @@ func (_c *SDK_UpdateBootstrapCerts_Call) Return(bootstrapConfig sdk.BootstrapCon return _c } -func (_c *SDK_UpdateBootstrapCerts_Call) RunAndReturn(run func(id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_UpdateBootstrapCerts_Call { +func (_c *SDK_UpdateBootstrapCerts_Call) RunAndReturn(run func(ctx context.Context, id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_UpdateBootstrapCerts_Call { _c.Call.Return(run) return _c } // UpdateBootstrapConnection provides a mock function for the type SDK -func (_mock *SDK) UpdateBootstrapConnection(id string, channels []string, domainID string, token string) errors.SDKError { - ret := _mock.Called(id, channels, domainID, token) +func (_mock *SDK) UpdateBootstrapConnection(ctx context.Context, id string, channels []string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, channels, domainID, token) if len(ret) == 0 { panic("no return value specified for UpdateBootstrapConnection") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, []string, string, string) errors.SDKError); ok { - r0 = returnFunc(id, channels, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, channels, domainID, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -6532,17 +6543,18 @@ type SDK_UpdateBootstrapConnection_Call struct { } // UpdateBootstrapConnection is a helper method to define mock.On call +// - ctx // - id // - channels // - domainID // - token -func (_e *SDK_Expecter) UpdateBootstrapConnection(id interface{}, channels interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapConnection_Call { - return &SDK_UpdateBootstrapConnection_Call{Call: _e.mock.On("UpdateBootstrapConnection", id, channels, domainID, token)} +func (_e *SDK_Expecter) UpdateBootstrapConnection(ctx interface{}, id interface{}, channels interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapConnection_Call { + return &SDK_UpdateBootstrapConnection_Call{Call: _e.mock.On("UpdateBootstrapConnection", ctx, id, channels, domainID, token)} } -func (_c *SDK_UpdateBootstrapConnection_Call) Run(run func(id string, channels []string, domainID string, token string)) *SDK_UpdateBootstrapConnection_Call { +func (_c *SDK_UpdateBootstrapConnection_Call) Run(run func(ctx context.Context, id string, channels []string, domainID string, token string)) *SDK_UpdateBootstrapConnection_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].([]string), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(string), args[2].([]string), args[3].(string), args[4].(string)) }) return _c } @@ -6552,7 +6564,7 @@ func (_c *SDK_UpdateBootstrapConnection_Call) Return(sDKError errors.SDKError) * return _c } -func (_c *SDK_UpdateBootstrapConnection_Call) RunAndReturn(run func(id string, channels []string, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrapConnection_Call { +func (_c *SDK_UpdateBootstrapConnection_Call) RunAndReturn(run func(ctx context.Context, id string, channels []string, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrapConnection_Call { _c.Call.Return(run) return _c } @@ -7733,8 +7745,8 @@ func (_c *SDK_Users_Call) RunAndReturn(run func(ctx context.Context, pm sdk0.Pag } // ViewBootstrap provides a mock function for the type SDK -func (_mock *SDK) ViewBootstrap(id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { - ret := _mock.Called(id, domainID, token) +func (_mock *SDK) ViewBootstrap(ctx context.Context, id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) if len(ret) == 0 { panic("no return value specified for ViewBootstrap") @@ -7742,16 +7754,16 @@ func (_mock *SDK) ViewBootstrap(id string, domainID string, token string) (sdk.B var r0 sdk.BootstrapConfig var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { - return returnFunc(id, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) } - if returnFunc, ok := ret.Get(0).(func(string, string, string) sdk.BootstrapConfig); ok { - r0 = returnFunc(id, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, id, domainID, token) } else { r0 = ret.Get(0).(sdk.BootstrapConfig) } - if returnFunc, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { - r1 = returnFunc(id, domainID, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -7766,16 +7778,17 @@ type SDK_ViewBootstrap_Call struct { } // ViewBootstrap is a helper method to define mock.On call +// - ctx // - id // - domainID // - token -func (_e *SDK_Expecter) ViewBootstrap(id interface{}, domainID interface{}, token interface{}) *SDK_ViewBootstrap_Call { - return &SDK_ViewBootstrap_Call{Call: _e.mock.On("ViewBootstrap", id, domainID, token)} +func (_e *SDK_Expecter) ViewBootstrap(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_ViewBootstrap_Call { + return &SDK_ViewBootstrap_Call{Call: _e.mock.On("ViewBootstrap", ctx, id, domainID, token)} } -func (_c *SDK_ViewBootstrap_Call) Run(run func(id string, domainID string, token string)) *SDK_ViewBootstrap_Call { +func (_c *SDK_ViewBootstrap_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_ViewBootstrap_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -7785,7 +7798,7 @@ func (_c *SDK_ViewBootstrap_Call) Return(bootstrapConfig sdk.BootstrapConfig, sD return _c } -func (_c *SDK_ViewBootstrap_Call) RunAndReturn(run func(id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_ViewBootstrap_Call { +func (_c *SDK_ViewBootstrap_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_ViewBootstrap_Call { _c.Call.Return(run) return _c } @@ -7909,8 +7922,8 @@ func (_c *SDK_ViewCertByClient_Call) RunAndReturn(run func(ctx context.Context, } // ViewSubscription provides a mock function for the type SDK -func (_mock *SDK) ViewSubscription(id string, token string) (sdk.Subscription, errors.SDKError) { - ret := _mock.Called(id, token) +func (_mock *SDK) ViewSubscription(ctx context.Context, id string, token string) (sdk.Subscription, errors.SDKError) { + ret := _mock.Called(ctx, id, token) if len(ret) == 0 { panic("no return value specified for ViewSubscription") @@ -7918,16 +7931,16 @@ func (_mock *SDK) ViewSubscription(id string, token string) (sdk.Subscription, e var r0 sdk.Subscription var r1 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, string) (sdk.Subscription, errors.SDKError)); ok { - return returnFunc(id, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (sdk.Subscription, errors.SDKError)); ok { + return returnFunc(ctx, id, token) } - if returnFunc, ok := ret.Get(0).(func(string, string) sdk.Subscription); ok { - r0 = returnFunc(id, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) sdk.Subscription); ok { + r0 = returnFunc(ctx, id, token) } else { r0 = ret.Get(0).(sdk.Subscription) } - if returnFunc, ok := ret.Get(1).(func(string, string) errors.SDKError); ok { - r1 = returnFunc(id, token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, token) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -7942,15 +7955,16 @@ type SDK_ViewSubscription_Call struct { } // ViewSubscription is a helper method to define mock.On call +// - ctx // - id // - token -func (_e *SDK_Expecter) ViewSubscription(id interface{}, token interface{}) *SDK_ViewSubscription_Call { - return &SDK_ViewSubscription_Call{Call: _e.mock.On("ViewSubscription", id, token)} +func (_e *SDK_Expecter) ViewSubscription(ctx interface{}, id interface{}, token interface{}) *SDK_ViewSubscription_Call { + return &SDK_ViewSubscription_Call{Call: _e.mock.On("ViewSubscription", ctx, id, token)} } -func (_c *SDK_ViewSubscription_Call) Run(run func(id string, token string)) *SDK_ViewSubscription_Call { +func (_c *SDK_ViewSubscription_Call) Run(run func(ctx context.Context, id string, token string)) *SDK_ViewSubscription_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string)) }) return _c } @@ -7960,22 +7974,22 @@ func (_c *SDK_ViewSubscription_Call) Return(subscription sdk.Subscription, sDKEr return _c } -func (_c *SDK_ViewSubscription_Call) RunAndReturn(run func(id string, token string) (sdk.Subscription, errors.SDKError)) *SDK_ViewSubscription_Call { +func (_c *SDK_ViewSubscription_Call) RunAndReturn(run func(ctx context.Context, id string, token string) (sdk.Subscription, errors.SDKError)) *SDK_ViewSubscription_Call { _c.Call.Return(run) return _c } // Whitelist provides a mock function for the type SDK -func (_mock *SDK) Whitelist(clientID string, state int, domainID string, token string) errors.SDKError { - ret := _mock.Called(clientID, state, domainID, token) +func (_mock *SDK) Whitelist(ctx context.Context, clientID string, state int, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, clientID, state, domainID, token) if len(ret) == 0 { panic("no return value specified for Whitelist") } var r0 errors.SDKError - if returnFunc, ok := ret.Get(0).(func(string, int, string, string) errors.SDKError); ok { - r0 = returnFunc(clientID, state, domainID, token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, clientID, state, domainID, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(errors.SDKError) @@ -7990,17 +8004,18 @@ type SDK_Whitelist_Call struct { } // Whitelist is a helper method to define mock.On call +// - ctx // - clientID // - state // - domainID // - token -func (_e *SDK_Expecter) Whitelist(clientID interface{}, state interface{}, domainID interface{}, token interface{}) *SDK_Whitelist_Call { - return &SDK_Whitelist_Call{Call: _e.mock.On("Whitelist", clientID, state, domainID, token)} +func (_e *SDK_Expecter) Whitelist(ctx interface{}, clientID interface{}, state interface{}, domainID interface{}, token interface{}) *SDK_Whitelist_Call { + return &SDK_Whitelist_Call{Call: _e.mock.On("Whitelist", ctx, clientID, state, domainID, token)} } -func (_c *SDK_Whitelist_Call) Run(run func(clientID string, state int, domainID string, token string)) *SDK_Whitelist_Call { +func (_c *SDK_Whitelist_Call) Run(run func(ctx context.Context, clientID string, state int, domainID string, token string)) *SDK_Whitelist_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(int), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(int), args[3].(string), args[4].(string)) }) return _c } @@ -8010,7 +8025,7 @@ func (_c *SDK_Whitelist_Call) Return(sDKError errors.SDKError) *SDK_Whitelist_Ca return _c } -func (_c *SDK_Whitelist_Call) RunAndReturn(run func(clientID string, state int, domainID string, token string) errors.SDKError) *SDK_Whitelist_Call { +func (_c *SDK_Whitelist_Call) RunAndReturn(run func(ctx context.Context, clientID string, state int, domainID string, token string) errors.SDKError) *SDK_Whitelist_Call { _c.Call.Return(run) return _c } diff --git a/pkg/sdk/sdk.go b/pkg/sdk/sdk.go index 7806246c7..2069df19c 100644 --- a/pkg/sdk/sdk.go +++ b/pkg/sdk/sdk.go @@ -5,6 +5,7 @@ package sdk import ( "bytes" + "context" "crypto/tls" "encoding/json" "fmt" @@ -67,16 +68,16 @@ type SDK interface { // ExternalKey: "externalKey", // Channels: []string{"channel1", "channel2"}, // } - // id, _ := sdk.AddBootstrap(cfg, "domainID", "token") + // id, _ := sdk.AddBootstrap(ctx, cfg, "domainID", "token") // fmt.Println(id) - AddBootstrap(cfg BootstrapConfig, domainID, token string) (string, errors.SDKError) + AddBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) (string, errors.SDKError) // View returns Client Config with given ID belonging to the user identified by the given token. // // example: - // bootstrap, _ := sdk.ViewBootstrap("id", "domainID", "token") + // bootstrap, _ := sdk.ViewBootstrap(ctx, "id", "domainID", "token") // fmt.Println(bootstrap) - ViewBootstrap(id, domainID, token string) (BootstrapConfig, errors.SDKError) + ViewBootstrap(ctx context.Context, id, domainID, token string) (BootstrapConfig, errors.SDKError) // Update updates editable fields of the provided Config. // @@ -88,44 +89,44 @@ type SDK interface { // ExternalKey: "externalKey", // Channels: []string{"channel1", "channel2"}, // } - // err := sdk.UpdateBootstrap(cfg, "domainID", "token") + // err := sdk.UpdateBootstrap(ctx, cfg, "domainID", "token") // fmt.Println(err) - UpdateBootstrap(cfg BootstrapConfig, domainID, token string) errors.SDKError + UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) errors.SDKError // Update bootstrap config certificates. // // example: - // err := sdk.UpdateBootstrapCerts("id", "clientCert", "clientKey", "ca", "domainID", "token") + // err := sdk.UpdateBootstrapCerts(ctx, "id", "clientCert", "clientKey", "ca", "domainID", "token") // fmt.Println(err) - UpdateBootstrapCerts(id string, clientCert, clientKey, ca string, domainID, token string) (BootstrapConfig, errors.SDKError) + UpdateBootstrapCerts(ctx context.Context, id string, clientCert, clientKey, ca string, domainID, token string) (BootstrapConfig, errors.SDKError) // UpdateBootstrapConnection updates connections performs update of the channel list corresponding Client is connected to. // // example: - // err := sdk.UpdateBootstrapConnection("id", []string{"channel1", "channel2"}, "domainID", "token") + // err := sdk.UpdateBootstrapConnection(ctx, "id", []string{"channel1", "channel2"}, "domainID", "token") // fmt.Println(err) - UpdateBootstrapConnection(id string, channels []string, domainID, token string) errors.SDKError + UpdateBootstrapConnection(ctx context.Context, id string, channels []string, domainID, token string) errors.SDKError // Remove removes Config with specified token that belongs to the user identified by the given token. // // example: - // err := sdk.RemoveBootstrap("id", "domainID", "token") + // err := sdk.RemoveBootstrap(ctx, "id", "domainID", "token") // fmt.Println(err) - RemoveBootstrap(id, domainID, token string) errors.SDKError + RemoveBootstrap(ctx context.Context, id, domainID, token string) errors.SDKError // Bootstrap returns Config to the Client with provided external ID using external key. // // example: - // bootstrap, _ := sdk.Bootstrap("externalID", "externalKey") + // bootstrap, _ := sdk.Bootstrap(ctx, "externalID", "externalKey") // fmt.Println(bootstrap) - Bootstrap(externalID, externalKey string) (BootstrapConfig, errors.SDKError) + Bootstrap(ctx context.Context, externalID, externalKey string) (BootstrapConfig, errors.SDKError) // BootstrapSecure retrieves a configuration with given external ID and encrypted external key. // // example: - // bootstrap, _ := sdk.BootstrapSecure("externalID", "externalKey", "cryptoKey") + // bootstrap, _ := sdk.BootstrapSecure(ctx, "externalID", "externalKey", "cryptoKey") // fmt.Println(bootstrap) - BootstrapSecure(externalID, externalKey, cryptoKey string) (BootstrapConfig, errors.SDKError) + BootstrapSecure(ctx context.Context, externalID, externalKey, cryptoKey string) (BootstrapConfig, errors.SDKError) // Bootstraps retrieves a list of managed configs. // @@ -134,16 +135,16 @@ type SDK interface { // Offset: 0, // Limit: 10, // } - // bootstraps, _ := sdk.Bootstraps(pm, "domainID", "token") + // bootstraps, _ := sdk.Bootstraps(ctx, pm, "domainID", "token") // fmt.Println(bootstraps) - Bootstraps(pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) + Bootstraps(ctx context.Context, pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) // Whitelist updates Client state Config with given ID belonging to the user identified by the given token. // // example: - // err := sdk.Whitelist("clientID", 1, "domainID", "token") + // err := sdk.Whitelist(ctx, "clientID", 1, "domainID", "token") // fmt.Println(err) - Whitelist(clientID string, state int, domainID, token string) errors.SDKError + Whitelist(ctx context.Context, clientID string, state int, domainID, token string) errors.SDKError // ReadMessages read messages of specified channel. // @@ -152,16 +153,16 @@ type SDK interface { // Offset: 0, // Limit: 10, // } - // msgs, _ := sdk.ReadMessages(pm,"channelID", "domainID", "token") + // msgs, _ := sdk.ReadMessages(ctx, pm,"channelID", "domainID", "token") // fmt.Println(msgs) - ReadMessages(pm MessagePageMetadata, chanID, domainID, token string) (MessagesPage, errors.SDKError) + ReadMessages(ctx context.Context, pm MessagePageMetadata, chanID, domainID, token string) (MessagesPage, errors.SDKError) // CreateSubscription creates a new subscription // // example: - // subscription, _ := sdk.CreateSubscription("topic", "contact", "token") + // subscription, _ := sdk.CreateSubscription(ctx, "topic", "contact", "token") // fmt.Println(subscription) - CreateSubscription(topic, contact, token string) (string, errors.SDKError) + CreateSubscription(ctx context.Context, topic, contact, token string) (string, errors.SDKError) // ListSubscriptions list subscriptions given list parameters. // @@ -170,23 +171,23 @@ type SDK interface { // Offset: 0, // Limit: 10, // } - // subscriptions, _ := sdk.ListSubscriptions(pm, "token") + // subscriptions, _ := sdk.ListSubscriptions(ctx, pm, "token") // fmt.Println(subscriptions) - ListSubscriptions(pm PageMetadata, token string) (SubscriptionPage, errors.SDKError) + ListSubscriptions(ctx context.Context, pm PageMetadata, token string) (SubscriptionPage, errors.SDKError) // ViewSubscription retrieves a subscription with the provided id. // // example: - // subscription, _ := sdk.ViewSubscription("id", "token") + // subscription, _ := sdk.ViewSubscription(ctx, "id", "token") // fmt.Println(subscription) - ViewSubscription(id, token string) (Subscription, errors.SDKError) + ViewSubscription(ctx context.Context, id, token string) (Subscription, errors.SDKError) // DeleteSubscription removes a subscription with the provided id. // // example: - // err := sdk.DeleteSubscription("id", "token") + // err := sdk.DeleteSubscription(ctx, "id", "token") // fmt.Println(err) - DeleteSubscription(id, token string) errors.SDKError + DeleteSubscription(ctx context.Context, id, token string) errors.SDKError } type mgSDK struct { @@ -257,8 +258,8 @@ func NewSDK(conf Config) SDK { // processRequest creates and send a new HTTP request, and checks for errors in the HTTP response. // It then returns the response headers, the response body, and the associated error(s) (if any). -func (sdk mgSDK) processRequest(method, reqUrl, token string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) { - req, err := http.NewRequest(method, reqUrl, bytes.NewReader(data)) +func (sdk mgSDK) processRequest(ctx context.Context, method, reqUrl, token string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) { + req, err := http.NewRequestWithContext(ctx, method, reqUrl, bytes.NewReader(data)) if err != nil { return make(http.Header), []byte{}, errors.NewSDKError(err) } diff --git a/provision/api/endpoint.go b/provision/api/endpoint.go index 2cdb55010..02c88ad74 100644 --- a/provision/api/endpoint.go +++ b/provision/api/endpoint.go @@ -13,13 +13,13 @@ import ( ) func doProvision(svc provision.Service) endpoint.Endpoint { - return func(_ context.Context, request interface{}) (interface{}, error) { + return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(provisionReq) if err := req.validate(); err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - res, err := svc.Provision(req.domainID, req.token, req.Name, req.ExternalID, req.ExternalKey) + res, err := svc.Provision(ctx, req.domainID, req.token, req.Name, req.ExternalID, req.ExternalKey) if err != nil { return nil, err } @@ -38,13 +38,13 @@ func doProvision(svc provision.Service) endpoint.Endpoint { } func getMapping(svc provision.Service) endpoint.Endpoint { - return func(_ context.Context, request interface{}) (interface{}, error) { + return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(mappingReq) if err := req.validate(); err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - res, err := svc.Mapping(req.token) + res, err := svc.Mapping(ctx, req.token) if err != nil { return nil, err } diff --git a/provision/api/endpoint_test.go b/provision/api/endpoint_test.go index 13f8eae7e..8ec142104 100644 --- a/provision/api/endpoint_test.go +++ b/provision/api/endpoint_test.go @@ -19,6 +19,7 @@ import ( smqlog "github.com/absmach/supermq/logger" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) var ( @@ -140,7 +141,7 @@ func TestProvision(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repocall := svc.On("Provision", validID, tc.token, "test", validID, validID).Return(provision.Result{}, tc.svcErr) + repocall := svc.On("Provision", mock.Anything, validID, tc.token, "test", validID, validID).Return(provision.Result{}, tc.svcErr) req := testRequest{ client: is.Client(), method: http.MethodPost, @@ -205,7 +206,7 @@ func TestMapping(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repocall := svc.On("Mapping", tc.token).Return(map[string]interface{}{}, tc.svcErr) + repocall := svc.On("Mapping", mock.Anything, tc.token).Return(map[string]interface{}{}, tc.svcErr) req := testRequest{ client: is.Client(), method: http.MethodGet, diff --git a/provision/api/logging.go b/provision/api/logging.go index 2156c722f..81b83c303 100644 --- a/provision/api/logging.go +++ b/provision/api/logging.go @@ -6,6 +6,7 @@ package api import ( + "context" "log/slog" "time" @@ -24,7 +25,7 @@ func NewLoggingMiddleware(svc provision.Service, logger *slog.Logger) provision. return &loggingMiddleware{logger, svc} } -func (lm *loggingMiddleware) Provision(domainID, token, name, externalID, externalKey string) (res provision.Result, err error) { +func (lm *loggingMiddleware) Provision(ctx context.Context, domainID, token, name, externalID, externalKey string) (res provision.Result, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), @@ -39,10 +40,10 @@ func (lm *loggingMiddleware) Provision(domainID, token, name, externalID, extern lm.logger.Info("Provision completed successfully", args...) }(time.Now()) - return lm.svc.Provision(domainID, token, name, externalID, externalKey) + return lm.svc.Provision(ctx, domainID, token, name, externalID, externalKey) } -func (lm *loggingMiddleware) Cert(domainID, token, clientID, duration string) (cert, key string, err error) { +func (lm *loggingMiddleware) Cert(ctx context.Context, domainID, token, clientID, duration string) (cert, key string, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), @@ -57,10 +58,10 @@ func (lm *loggingMiddleware) Cert(domainID, token, clientID, duration string) (c lm.logger.Info("Client certificate created successfully", args...) }(time.Now()) - return lm.svc.Cert(domainID, token, clientID, duration) + return lm.svc.Cert(ctx, domainID, token, clientID, duration) } -func (lm *loggingMiddleware) Mapping(token string) (res map[string]interface{}, err error) { +func (lm *loggingMiddleware) Mapping(ctx context.Context, token string) (res map[string]interface{}, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), @@ -73,5 +74,5 @@ func (lm *loggingMiddleware) Mapping(token string) (res map[string]interface{}, lm.logger.Info("Mapping completed successfully", args...) }(time.Now()) - return lm.svc.Mapping(token) + return lm.svc.Mapping(ctx, token) } diff --git a/provision/mocks/service.go b/provision/mocks/service.go index 68e708164..58c6fb6b4 100644 --- a/provision/mocks/service.go +++ b/provision/mocks/service.go @@ -8,6 +8,8 @@ package mocks import ( + "context" + "github.com/absmach/magistrala/provision" mock "github.com/stretchr/testify/mock" ) @@ -40,8 +42,8 @@ func (_m *Service) EXPECT() *Service_Expecter { } // Cert provides a mock function for the type Service -func (_mock *Service) Cert(domainID string, token string, clientID string, duration string) (string, string, error) { - ret := _mock.Called(domainID, token, clientID, duration) +func (_mock *Service) Cert(ctx context.Context, domainID string, token string, clientID string, duration string) (string, string, error) { + ret := _mock.Called(ctx, domainID, token, clientID, duration) if len(ret) == 0 { panic("no return value specified for Cert") @@ -50,21 +52,21 @@ func (_mock *Service) Cert(domainID string, token string, clientID string, durat var r0 string var r1 string var r2 error - if returnFunc, ok := ret.Get(0).(func(string, string, string, string) (string, string, error)); ok { - return returnFunc(domainID, token, clientID, duration) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) (string, string, error)); ok { + return returnFunc(ctx, domainID, token, clientID, duration) } - if returnFunc, ok := ret.Get(0).(func(string, string, string, string) string); ok { - r0 = returnFunc(domainID, token, clientID, duration) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) string); ok { + r0 = returnFunc(ctx, domainID, token, clientID, duration) } else { r0 = ret.Get(0).(string) } - if returnFunc, ok := ret.Get(1).(func(string, string, string, string) string); ok { - r1 = returnFunc(domainID, token, clientID, duration) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string) string); ok { + r1 = returnFunc(ctx, domainID, token, clientID, duration) } else { r1 = ret.Get(1).(string) } - if returnFunc, ok := ret.Get(2).(func(string, string, string, string) error); ok { - r2 = returnFunc(domainID, token, clientID, duration) + if returnFunc, ok := ret.Get(2).(func(context.Context, string, string, string, string) error); ok { + r2 = returnFunc(ctx, domainID, token, clientID, duration) } else { r2 = ret.Error(2) } @@ -77,17 +79,18 @@ type Service_Cert_Call struct { } // Cert is a helper method to define mock.On call +// - ctx // - domainID // - token // - clientID // - duration -func (_e *Service_Expecter) Cert(domainID interface{}, token interface{}, clientID interface{}, duration interface{}) *Service_Cert_Call { - return &Service_Cert_Call{Call: _e.mock.On("Cert", domainID, token, clientID, duration)} +func (_e *Service_Expecter) Cert(ctx interface{}, domainID interface{}, token interface{}, clientID interface{}, duration interface{}) *Service_Cert_Call { + return &Service_Cert_Call{Call: _e.mock.On("Cert", ctx, domainID, token, clientID, duration)} } -func (_c *Service_Cert_Call) Run(run func(domainID string, token string, clientID string, duration string)) *Service_Cert_Call { +func (_c *Service_Cert_Call) Run(run func(ctx context.Context, domainID string, token string, clientID string, duration string)) *Service_Cert_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) }) return _c } @@ -97,14 +100,14 @@ func (_c *Service_Cert_Call) Return(s string, s1 string, err error) *Service_Cer return _c } -func (_c *Service_Cert_Call) RunAndReturn(run func(domainID string, token string, clientID string, duration string) (string, string, error)) *Service_Cert_Call { +func (_c *Service_Cert_Call) RunAndReturn(run func(ctx context.Context, domainID string, token string, clientID string, duration string) (string, string, error)) *Service_Cert_Call { _c.Call.Return(run) return _c } // Mapping provides a mock function for the type Service -func (_mock *Service) Mapping(token string) (map[string]interface{}, error) { - ret := _mock.Called(token) +func (_mock *Service) Mapping(ctx context.Context, token string) (map[string]interface{}, error) { + ret := _mock.Called(ctx, token) if len(ret) == 0 { panic("no return value specified for Mapping") @@ -112,18 +115,18 @@ func (_mock *Service) Mapping(token string) (map[string]interface{}, error) { var r0 map[string]interface{} var r1 error - if returnFunc, ok := ret.Get(0).(func(string) (map[string]interface{}, error)); ok { - return returnFunc(token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (map[string]interface{}, error)); ok { + return returnFunc(ctx, token) } - if returnFunc, ok := ret.Get(0).(func(string) map[string]interface{}); ok { - r0 = returnFunc(token) + if returnFunc, ok := ret.Get(0).(func(context.Context, string) map[string]interface{}); ok { + r0 = returnFunc(ctx, token) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(map[string]interface{}) } } - if returnFunc, ok := ret.Get(1).(func(string) error); ok { - r1 = returnFunc(token) + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, token) } else { r1 = ret.Error(1) } @@ -136,14 +139,15 @@ type Service_Mapping_Call struct { } // Mapping is a helper method to define mock.On call +// - ctx // - token -func (_e *Service_Expecter) Mapping(token interface{}) *Service_Mapping_Call { - return &Service_Mapping_Call{Call: _e.mock.On("Mapping", token)} +func (_e *Service_Expecter) Mapping(ctx interface{}, token interface{}) *Service_Mapping_Call { + return &Service_Mapping_Call{Call: _e.mock.On("Mapping", ctx, token)} } -func (_c *Service_Mapping_Call) Run(run func(token string)) *Service_Mapping_Call { +func (_c *Service_Mapping_Call) Run(run func(ctx context.Context, token string)) *Service_Mapping_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) + run(args[0].(context.Context), args[1].(string)) }) return _c } @@ -153,14 +157,14 @@ func (_c *Service_Mapping_Call) Return(stringToIfaceVal map[string]interface{}, return _c } -func (_c *Service_Mapping_Call) RunAndReturn(run func(token string) (map[string]interface{}, error)) *Service_Mapping_Call { +func (_c *Service_Mapping_Call) RunAndReturn(run func(ctx context.Context, token string) (map[string]interface{}, error)) *Service_Mapping_Call { _c.Call.Return(run) return _c } // Provision provides a mock function for the type Service -func (_mock *Service) Provision(domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error) { - ret := _mock.Called(domainID, token, name, externalID, externalKey) +func (_mock *Service) Provision(ctx context.Context, domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error) { + ret := _mock.Called(ctx, domainID, token, name, externalID, externalKey) if len(ret) == 0 { panic("no return value specified for Provision") @@ -168,16 +172,16 @@ func (_mock *Service) Provision(domainID string, token string, name string, exte var r0 provision.Result var r1 error - if returnFunc, ok := ret.Get(0).(func(string, string, string, string, string) (provision.Result, error)); ok { - return returnFunc(domainID, token, name, externalID, externalKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) (provision.Result, error)); ok { + return returnFunc(ctx, domainID, token, name, externalID, externalKey) } - if returnFunc, ok := ret.Get(0).(func(string, string, string, string, string) provision.Result); ok { - r0 = returnFunc(domainID, token, name, externalID, externalKey) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) provision.Result); ok { + r0 = returnFunc(ctx, domainID, token, name, externalID, externalKey) } else { r0 = ret.Get(0).(provision.Result) } - if returnFunc, ok := ret.Get(1).(func(string, string, string, string, string) error); ok { - r1 = returnFunc(domainID, token, name, externalID, externalKey) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string) error); ok { + r1 = returnFunc(ctx, domainID, token, name, externalID, externalKey) } else { r1 = ret.Error(1) } @@ -190,18 +194,19 @@ type Service_Provision_Call struct { } // Provision is a helper method to define mock.On call +// - ctx // - domainID // - token // - name // - externalID // - externalKey -func (_e *Service_Expecter) Provision(domainID interface{}, token interface{}, name interface{}, externalID interface{}, externalKey interface{}) *Service_Provision_Call { - return &Service_Provision_Call{Call: _e.mock.On("Provision", domainID, token, name, externalID, externalKey)} +func (_e *Service_Expecter) Provision(ctx interface{}, domainID interface{}, token interface{}, name interface{}, externalID interface{}, externalKey interface{}) *Service_Provision_Call { + return &Service_Provision_Call{Call: _e.mock.On("Provision", ctx, domainID, token, name, externalID, externalKey)} } -func (_c *Service_Provision_Call) Run(run func(domainID string, token string, name string, externalID string, externalKey string)) *Service_Provision_Call { +func (_c *Service_Provision_Call) Run(run func(ctx context.Context, domainID string, token string, name string, externalID string, externalKey string)) *Service_Provision_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(string)) }) return _c } @@ -211,7 +216,7 @@ func (_c *Service_Provision_Call) Return(result provision.Result, err error) *Se return _c } -func (_c *Service_Provision_Call) RunAndReturn(run func(domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error)) *Service_Provision_Call { +func (_c *Service_Provision_Call) RunAndReturn(run func(ctx context.Context, domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error)) *Service_Provision_Call { _c.Call.Return(run) return _c } diff --git a/provision/service.go b/provision/service.go index 405937719..f4628564f 100644 --- a/provision/service.go +++ b/provision/service.go @@ -56,18 +56,18 @@ type Service interface { // - create multiple Channels // - create Bootstrap configuration // - whitelist Client in Bootstrap configuration == connect Client to Channels - Provision(domainID, token, name, externalID, externalKey string) (Result, error) + Provision(ctx context.Context, domainID, token, name, externalID, externalKey string) (Result, error) // Mapping returns current configuration used for provision // useful for using in ui to create configuration that matches // one created with Provision method. - Mapping(token string) (map[string]interface{}, error) + Mapping(ctx context.Context, token string) (map[string]interface{}, error) // Certs creates certificate for clients that communicate over mTLS // A duration string is a possibly signed sequence of decimal numbers, // each with optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". - Cert(domainID, token, clientID, duration string) (string, string, error) + Cert(ctx context.Context, domainID, token, clientID, duration string) (string, string, error) } type provisionService struct { @@ -97,13 +97,13 @@ func New(cfg Config, mgsdk sdk.SDK, logger *slog.Logger) Service { } // Mapping retrieves current configuration. -func (ps *provisionService) Mapping(token string) (map[string]interface{}, error) { +func (ps *provisionService) Mapping(ctx context.Context, token string) (map[string]interface{}, error) { pm := smqSDK.PageMetadata{ Offset: uint64(offset), Limit: uint64(limit), } - if _, err := ps.sdk.Users(context.Background(), pm, token); err != nil { + if _, err := ps.sdk.Users(ctx, pm, token); err != nil { return map[string]interface{}{}, errors.Wrap(ErrUnauthorized, err) } @@ -112,12 +112,12 @@ func (ps *provisionService) Mapping(token string) (map[string]interface{}, error // Provision is provision method for creating setup according to // provision layout specified in config.toml. -func (ps *provisionService) Provision(domainID, token, name, externalID, externalKey string) (res Result, err error) { +func (ps *provisionService) Provision(ctx context.Context, domainID, token, name, externalID, externalKey string) (res Result, err error) { var channels []smqSDK.Channel var clients []smqSDK.Client - defer ps.recover(&err, &clients, &channels, &domainID, &token) + defer ps.recover(ctx, &err, &clients, &channels, &domainID, &token) - token, err = ps.createTokenIfEmpty(token) + token, err = ps.createTokenIfEmpty(ctx, token) if err != nil { return res, errors.Wrap(ErrFailedToCreateToken, err) } @@ -142,14 +142,14 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa name = c.Name } cli.Name = name - cli, err := ps.sdk.CreateClient(context.Background(), cli, domainID, token) + cli, err := ps.sdk.CreateClient(ctx, cli, domainID, token) if err != nil { res.Error = err.Error() return res, errors.Wrap(ErrFailedClientCreation, err) } // Get newly created client (in order to get the key). - cli, err = ps.sdk.Client(context.Background(), cli.ID, domainID, token) + cli, err = ps.sdk.Client(ctx, cli.ID, domainID, token) if err != nil { e := errors.Wrap(err, fmt.Errorf("client id: %s", cli.ID)) return res, errors.Wrap(ErrFailedClientRetrieval, e) @@ -162,11 +162,11 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa Name: name + "_" + channel.Name, Metadata: smqSDK.Metadata(channel.Metadata), } - ch, err := ps.sdk.CreateChannel(context.Background(), ch, domainID, token) + ch, err := ps.sdk.CreateChannel(ctx, ch, domainID, token) if err != nil { return res, errors.Wrap(ErrFailedChannelCreation, err) } - ch, err = ps.sdk.Channel(context.Background(), ch.ID, domainID, token) + ch, err = ps.sdk.Channel(ctx, ch.ID, domainID, token) if err != nil { e := errors.Wrap(err, fmt.Errorf("channel id: %s", ch.ID)) return res, errors.Wrap(ErrFailedChannelRetrieval, e) @@ -206,12 +206,12 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa ClientKey: cert.Key, Content: string(content), } - bsid, err := ps.sdk.AddBootstrap(bsReq, domainID, token) + bsid, err := ps.sdk.AddBootstrap(ctx, bsReq, domainID, token) if err != nil { return Result{}, errors.Wrap(ErrFailedBootstrap, err) } - bsConfig, err = ps.sdk.ViewBootstrap(bsid, domainID, token) + bsConfig, err = ps.sdk.ViewBootstrap(ctx, bsid, domainID, token) if err != nil { return Result{}, errors.Wrap(ErrFailedBootstrapValidate, err) } @@ -220,12 +220,12 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa if ps.conf.Bootstrap.X509Provision { var cert smqSDK.Cert - cert, err = ps.sdk.IssueCert(context.Background(), c.ID, ps.conf.Cert.TTL, domainID, token) + cert, err = ps.sdk.IssueCert(ctx, c.ID, ps.conf.Cert.TTL, domainID, token) if err != nil { e := errors.Wrap(err, fmt.Errorf("client id: %s", c.ID)) return res, errors.Wrap(ErrFailedCertCreation, e) } - cert, err := ps.sdk.ViewCert(context.Background(), cert.SerialNumber, domainID, token) + cert, err := ps.sdk.ViewCert(ctx, cert.SerialNumber, domainID, token) if err != nil { return res, errors.Wrap(ErrFailedCertView, err) } @@ -235,14 +235,14 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa res.CACert = "" if needsBootstrap(c) { - if _, err = ps.sdk.UpdateBootstrapCerts(bsConfig.ClientID, cert.Certificate, cert.Key, "", domainID, token); err != nil { + if _, err = ps.sdk.UpdateBootstrapCerts(ctx, bsConfig.ClientID, cert.Certificate, cert.Key, "", domainID, token); err != nil { return Result{}, errors.Wrap(ErrFailedCertCreation, err) } } } if ps.conf.Bootstrap.AutoWhiteList { - if err := ps.sdk.Whitelist(c.ID, Active, domainID, token); err != nil { + if err := ps.sdk.Whitelist(ctx, c.ID, Active, domainID, token); err != nil { res.Error = err.Error() return res, ErrClientUpdate } @@ -250,18 +250,17 @@ func (ps *provisionService) Provision(domainID, token, name, externalID, externa } } - if err = ps.updateGateway(domainID, token, bsConfig, channels); err != nil { + if err = ps.updateGateway(ctx, domainID, token, bsConfig, channels); err != nil { return res, err } return res, nil } -func (ps *provisionService) Cert(domainID, token, clientID, ttl string) (string, string, error) { - token, err := ps.createTokenIfEmpty(token) +func (ps *provisionService) Cert(ctx context.Context, domainID, token, clientID, ttl string) (string, string, error) { + token, err := ps.createTokenIfEmpty(ctx, token) if err != nil { return "", "", errors.Wrap(ErrFailedToCreateToken, err) } - ctx := context.Background() th, err := ps.sdk.Client(ctx, clientID, domainID, token) if err != nil { @@ -278,7 +277,7 @@ func (ps *provisionService) Cert(domainID, token, clientID, ttl string) (string, return cert.Certificate, cert.Key, err } -func (ps *provisionService) createTokenIfEmpty(token string) (string, error) { +func (ps *provisionService) createTokenIfEmpty(ctx context.Context, token string) (string, error) { if token != "" { return token, nil } @@ -298,7 +297,7 @@ func (ps *provisionService) createTokenIfEmpty(token string) (string, error) { Username: ps.conf.Server.MgUsername, Password: ps.conf.Server.MgPass, } - tkn, err := ps.sdk.CreateToken(context.Background(), u) + tkn, err := ps.sdk.CreateToken(ctx, u) if err != nil { return token, errors.Wrap(ErrFailedToCreateToken, err) } @@ -306,7 +305,7 @@ func (ps *provisionService) createTokenIfEmpty(token string) (string, error) { return tkn.AccessToken, nil } -func (ps *provisionService) updateGateway(domainID, token string, bs sdk.BootstrapConfig, channels []smqSDK.Channel) error { +func (ps *provisionService) updateGateway(ctx context.Context, domainID, token string, bs sdk.BootstrapConfig, channels []smqSDK.Channel) error { var gw Gateway for _, ch := range channels { switch ch.Metadata["type"] { @@ -323,7 +322,7 @@ func (ps *provisionService) updateGateway(domainID, token string, bs sdk.Bootstr gw.CfgID = bs.ClientID gw.Type = gateway - c, sdkerr := ps.sdk.Client(context.Background(), bs.ClientID, domainID, token) + c, sdkerr := ps.sdk.Client(ctx, bs.ClientID, domainID, token) if sdkerr != nil { return errors.Wrap(ErrGatewayUpdate, sdkerr) } @@ -334,7 +333,7 @@ func (ps *provisionService) updateGateway(domainID, token string, bs sdk.Bootstr if err := json.Unmarshal(b, &c.Metadata); err != nil { return errors.Wrap(ErrGatewayUpdate, err) } - if _, err := ps.sdk.UpdateClient(context.Background(), c, domainID, token); err != nil { + if _, err := ps.sdk.UpdateClient(ctx, c, domainID, token); err != nil { return errors.Wrap(ErrGatewayUpdate, err) } return nil @@ -346,18 +345,18 @@ func (ps *provisionService) errLog(err error) { } } -func clean(ps *provisionService, clients []smqSDK.Client, channels []smqSDK.Channel, domainID, token string) { +func clean(ctx context.Context, ps *provisionService, clients []smqSDK.Client, channels []smqSDK.Channel, domainID, token string) { for _, t := range clients { - err := ps.sdk.DeleteClient(context.Background(), t.ID, domainID, token) + err := ps.sdk.DeleteClient(ctx, t.ID, domainID, token) ps.errLog(err) } for _, c := range channels { - err := ps.sdk.DeleteChannel(context.Background(), c.ID, domainID, token) + err := ps.sdk.DeleteChannel(ctx, c.ID, domainID, token) ps.errLog(err) } } -func (ps *provisionService) recover(e *error, ths *[]smqSDK.Client, chs *[]smqSDK.Channel, dm, tkn *string) { +func (ps *provisionService) recover(ctx context.Context, e *error, ths *[]smqSDK.Client, chs *[]smqSDK.Channel, dm, tkn *string) { if e == nil { return } @@ -365,49 +364,49 @@ func (ps *provisionService) recover(e *error, ths *[]smqSDK.Client, chs *[]smqSD if errors.Contains(err, ErrFailedClientRetrieval) || errors.Contains(err, ErrFailedChannelCreation) { for _, c := range clients { - err := ps.sdk.DeleteClient(context.Background(), c.ID, domainID, token) + err := ps.sdk.DeleteClient(ctx, c.ID, domainID, token) ps.errLog(err) } return } if errors.Contains(err, ErrFailedBootstrap) || errors.Contains(err, ErrFailedChannelRetrieval) { - clean(ps, clients, channels, domainID, token) + clean(ctx, ps, clients, channels, domainID, token) return } if errors.Contains(err, ErrFailedBootstrapValidate) || errors.Contains(err, ErrFailedCertCreation) { - clean(ps, clients, channels, domainID, token) + clean(ctx, ps, clients, channels, domainID, token) for _, th := range clients { if needsBootstrap(th) { - ps.errLog(ps.sdk.RemoveBootstrap(th.ID, domainID, token)) + ps.errLog(ps.sdk.RemoveBootstrap(ctx, th.ID, domainID, token)) } } return } if errors.Contains(err, ErrFailedBootstrapValidate) || errors.Contains(err, ErrFailedCertCreation) { - clean(ps, clients, channels, domainID, token) + clean(ctx, ps, clients, channels, domainID, token) for _, th := range clients { if needsBootstrap(th) { - bs, err := ps.sdk.ViewBootstrap(th.ID, domainID, token) + bs, err := ps.sdk.ViewBootstrap(ctx, th.ID, domainID, token) ps.errLog(errors.Wrap(ErrFailedBootstrapRetrieval, err)) - ps.errLog(ps.sdk.RemoveBootstrap(bs.ClientID, domainID, token)) + ps.errLog(ps.sdk.RemoveBootstrap(ctx, bs.ClientID, domainID, token)) } } } if errors.Contains(err, ErrClientUpdate) || errors.Contains(err, ErrGatewayUpdate) { - clean(ps, clients, channels, domainID, token) + clean(ctx, ps, clients, channels, domainID, token) for _, th := range clients { if ps.conf.Bootstrap.X509Provision && needsBootstrap(th) { - _, err := ps.sdk.RevokeCert(context.Background(), th.ID, domainID, token) + _, err := ps.sdk.RevokeCert(ctx, th.ID, domainID, token) ps.errLog(err) } if needsBootstrap(th) { - bs, err := ps.sdk.ViewBootstrap(th.ID, domainID, token) + bs, err := ps.sdk.ViewBootstrap(ctx, th.ID, domainID, token) ps.errLog(errors.Wrap(ErrFailedBootstrapRetrieval, err)) - ps.errLog(ps.sdk.RemoveBootstrap(bs.ClientID, domainID, token)) + ps.errLog(ps.sdk.RemoveBootstrap(ctx, bs.ClientID, domainID, token)) } } return diff --git a/provision/service_test.go b/provision/service_test.go index 37a119bfe..a329abeb5 100644 --- a/provision/service_test.go +++ b/provision/service_test.go @@ -4,6 +4,7 @@ package provision_test import ( + "context" "fmt" "testing" @@ -51,8 +52,8 @@ func TestMapping(t *testing.T) { for _, c := range cases { t.Run(c.desc, func(t *testing.T) { pm := smqSDK.PageMetadata{Offset: uint64(0), Limit: uint64(10)} - repocall := mgsdk.On("Users", pm, c.token).Return(smqSDK.UsersPage{}, c.sdkerr) - content, err := svc.Mapping(c.token) + repocall := mgsdk.On("Users", mock.Anything, pm, c.token).Return(smqSDK.UsersPage{}, c.sdkerr) + content, err := svc.Mapping(context.Background(), c.token) assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v, got %v", c.err, err)) assert.Equal(t, c.content, content) repocall.Unset() @@ -215,15 +216,15 @@ func TestCert(t *testing.T) { mgsdk := new(sdkmocks.SDK) svc := provision.New(c.config, mgsdk, smqlog.NewMock()) - mgsdk.On("Client", c.clientID, c.domainID, mock.Anything).Return(smqSDK.Client{ID: c.clientID}, c.sdkClientErr) - mgsdk.On("IssueCert", c.clientID, c.config.Cert.TTL, c.domainID, mock.Anything).Return(smqSDK.Cert{SerialNumber: c.serial}, c.sdkCertErr) - mgsdk.On("ViewCert", c.serial, mock.Anything, mock.Anything).Return(smqSDK.Cert{Certificate: c.cert, Key: c.key}, c.sdkCertErr) + mgsdk.On("Client", mock.Anything, c.clientID, c.domainID, mock.Anything).Return(smqSDK.Client{ID: c.clientID}, c.sdkClientErr) + mgsdk.On("IssueCert", mock.Anything, c.clientID, c.config.Cert.TTL, c.domainID, mock.Anything).Return(smqSDK.Cert{SerialNumber: c.serial}, c.sdkCertErr) + mgsdk.On("ViewCert", mock.Anything, c.serial, mock.Anything, mock.Anything).Return(smqSDK.Cert{Certificate: c.cert, Key: c.key}, c.sdkCertErr) login := smqSDK.Login{ Username: c.config.Server.MgUsername, Password: c.config.Server.MgPass, } - mgsdk.On("CreateToken", login).Return(smqSDK.Token{AccessToken: validToken}, c.sdkTokenErr) - cert, key, err := svc.Cert(c.domainID, c.token, c.clientID, c.ttl) + mgsdk.On("CreateToken", mock.Anything, login).Return(smqSDK.Token{AccessToken: validToken}, c.sdkTokenErr) + cert, key, err := svc.Cert(context.Background(), c.domainID, c.token, c.clientID, c.ttl) assert.Equal(t, c.cert, cert) assert.Equal(t, c.key, key) assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v, got %v", c.err, err)) diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index 8bcea0690..225da972f 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -33,3 +33,7 @@ packages: github.com/absmach/magistrala/provision: interfaces: Service: + github.com/absmach/magistrala/alarms: + interfaces: + Service: + Repository: \ No newline at end of file diff --git a/tools/e2e/cmd/main.go b/tools/e2e/cmd/main.go index a1d10e03b..3a6d6715a 100644 --- a/tools/e2e/cmd/main.go +++ b/tools/e2e/cmd/main.go @@ -30,8 +30,8 @@ func main() { "go run tools/e2e/cmd/main.go\n" + "go run tools/e2e/cmd/main.go --host 142.93.118.47\n" + "go run tools/e2e/cmd/main.go --host localhost --num 10 --num_of_messages 100 --prefix e2e", - Run: func(_ *cobra.Command, _ []string) { - e2e.Test(econf) + Run: func(cmd *cobra.Command, _ []string) { + e2e.Test(cmd.Context(), econf) }, } diff --git a/tools/e2e/e2e.go b/tools/e2e/e2e.go index d3622f671..4d2b5e13f 100644 --- a/tools/e2e/e2e.go +++ b/tools/e2e/e2e.go @@ -66,7 +66,7 @@ type Config struct { // - Connect client to channel // - Publish message from HTTP, MQTT, WS and CoAP Adapters. -func Test(conf Config) { +func Test(ctx context.Context, conf Config) { sdkConf := sdk.Config{ UsersURL: fmt.Sprintf("http://%s:%s", conf.Host, usersPort), GroupsURL: fmt.Sprintf("http://%s:%s", conf.Host, groupsPort), @@ -82,51 +82,51 @@ func Test(conf Config) { magenta := color.FgLightMagenta.Render - domainID, token, err := createUser(s, conf) + domainID, token, err := createUser(ctx, s, conf) if err != nil { errExit(fmt.Errorf("unable to create user: %w", err)) } color.Success.Printf("created user with token %s\n", magenta(token)) color.Success.Printf("created domain with ID %s\n", magenta(domainID)) - users, err := createUsers(s, conf, token) + users, err := createUsers(ctx, s, conf, token) if err != nil { errExit(fmt.Errorf("unable to create users: %w", err)) } color.Success.Printf("created users of ids:\n%s\n", magenta(getIDS(users))) - groups, err := createGroups(s, conf, domainID, token) + groups, err := createGroups(ctx, s, conf, domainID, token) if err != nil { errExit(fmt.Errorf("unable to create groups: %w", err)) } color.Success.Printf("created groups of ids:\n%s\n", magenta(getIDS(groups))) - clients, err := createClients(s, conf, domainID, token) + clients, err := createClients(ctx, s, conf, domainID, token) if err != nil { errExit(fmt.Errorf("unable to create clients: %w", err)) } color.Success.Printf("created clients of ids:\n%s\n", magenta(getIDS(clients))) - channels, err := createChannels(s, conf, domainID, token) + channels, err := createChannels(ctx, s, conf, domainID, token) if err != nil { errExit(fmt.Errorf("unable to create channels: %w", err)) } color.Success.Printf("created channels of ids:\n%s\n", magenta(getIDS(channels))) // List users, groups, clients and channels - if err := read(s, conf, domainID, token, users, groups, clients, channels); err != nil { + if err := read(ctx, s, conf, domainID, token, users, groups, clients, channels); err != nil { errExit(fmt.Errorf("unable to read users, groups, clients and channels: %w", err)) } color.Success.Println("viewed users, groups, clients and channels") // Update users, groups, clients and channels - if err := update(s, domainID, token, users, groups, clients, channels); err != nil { + if err := update(ctx, s, domainID, token, users, groups, clients, channels); err != nil { errExit(fmt.Errorf("unable to update users, groups, clients and channels: %w", err)) } color.Success.Println("updated users, groups, clients and channels") // Send messages to channels - if err := messaging(s, conf, domainID, token, clients, channels); err != nil { + if err := messaging(ctx, s, conf, domainID, token, clients, channels); err != nil { errExit(fmt.Errorf("unable to send messages to channels: %w", err)) } color.Success.Println("sent messages to channels") @@ -137,7 +137,7 @@ func errExit(err error) { os.Exit(1) } -func createUser(s sdk.SDK, conf Config) (string, string, error) { +func createUser(ctx context.Context, s sdk.SDK, conf Config) (string, string, error) { user := sdk.User{ FirstName: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), LastName: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), @@ -149,8 +149,8 @@ func createUser(s sdk.SDK, conf Config) (string, string, error) { Status: sdk.EnabledStatus, Role: "admin", } - ctx := context.Background() - if _, err := s.CreateUser(context.Background(), user, ""); err != nil { + + if _, err := s.CreateUser(ctx, user, ""); err != nil { return "", "", fmt.Errorf("unable to create user: %w", err) } @@ -187,10 +187,9 @@ func createUser(s sdk.SDK, conf Config) (string, string, error) { return domain.ID, token.AccessToken, nil } -func createUsers(s sdk.SDK, conf Config, token string) ([]sdk.User, error) { +func createUsers(ctx context.Context, s sdk.SDK, conf Config, token string) ([]sdk.User, error) { var err error users := []sdk.User{} - ctx := context.Background() for i := uint64(0); i < conf.Num; i++ { user := sdk.User{ @@ -214,10 +213,9 @@ func createUsers(s sdk.SDK, conf Config, token string) ([]sdk.User, error) { return users, nil } -func createGroups(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Group, error) { +func createGroups(ctx context.Context, s sdk.SDK, conf Config, domainID, token string) ([]sdk.Group, error) { var err error groups := []sdk.Group{} - ctx := context.Background() for i := uint64(0); i < conf.Num; i++ { group := sdk.Group{ @@ -235,10 +233,9 @@ func createGroups(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Group, return groups, nil } -func createClientsInBatch(s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Client, error) { +func createClientsInBatch(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Client, error) { var err error clients := make([]sdk.Client, num) - ctx := context.Background() for i := uint64(0); i < num; i++ { clients[i] = sdk.Client{ @@ -254,25 +251,25 @@ func createClientsInBatch(s sdk.SDK, conf Config, domainID, token string, num ui return clients, nil } -func createClients(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Client, error) { +func createClients(ctx context.Context, s sdk.SDK, conf Config, domainID, token string) ([]sdk.Client, error) { clients := []sdk.Client{} if conf.Num > batchSize { batches := int(conf.Num) / batchSize for i := 0; i < batches; i++ { - ths, err := createClientsInBatch(s, conf, domainID, token, batchSize) + ths, err := createClientsInBatch(ctx, s, conf, domainID, token, batchSize) if err != nil { return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) } clients = append(clients, ths...) } - ths, err := createClientsInBatch(s, conf, domainID, token, conf.Num%uint64(batchSize)) + ths, err := createClientsInBatch(ctx, s, conf, domainID, token, conf.Num%uint64(batchSize)) if err != nil { return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) } clients = append(clients, ths...) } else { - ths, err := createClientsInBatch(s, conf, domainID, token, conf.Num) + ths, err := createClientsInBatch(ctx, s, conf, domainID, token, conf.Num) if err != nil { return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) } @@ -282,10 +279,9 @@ func createClients(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Client return clients, nil } -func createChannelsInBatch(s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Channel, error) { +func createChannelsInBatch(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Channel, error) { var err error channels := make([]sdk.Channel, num) - ctx := context.Background() for i := uint64(0); i < num; i++ { channels[i] = sdk.Channel{ @@ -300,25 +296,25 @@ func createChannelsInBatch(s sdk.SDK, conf Config, domainID, token string, num u return channels, nil } -func createChannels(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Channel, error) { +func createChannels(ctx context.Context, s sdk.SDK, conf Config, domainID, token string) ([]sdk.Channel, error) { channels := []sdk.Channel{} if conf.Num > batchSize { batches := int(conf.Num) / batchSize for i := 0; i < batches; i++ { - chs, err := createChannelsInBatch(s, conf, token, domainID, batchSize) + chs, err := createChannelsInBatch(ctx, s, conf, token, domainID, batchSize) if err != nil { return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) } channels = append(channels, chs...) } - chs, err := createChannelsInBatch(s, conf, domainID, token, conf.Num%uint64(batchSize)) + chs, err := createChannelsInBatch(ctx, s, conf, domainID, token, conf.Num%uint64(batchSize)) if err != nil { return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) } channels = append(channels, chs...) } else { - chs, err := createChannelsInBatch(s, conf, domainID, token, conf.Num) + chs, err := createChannelsInBatch(ctx, s, conf, domainID, token, conf.Num) if err != nil { return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) } @@ -328,8 +324,7 @@ func createChannels(s sdk.SDK, conf Config, domainID, token string) ([]sdk.Chann return channels, nil } -func read(s sdk.SDK, conf Config, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { - ctx := context.Background() +func read(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { for _, user := range users { if _, err := s.User(ctx, user.ID, token); err != nil { return fmt.Errorf("failed to get user %w", err) @@ -382,8 +377,7 @@ func read(s sdk.SDK, conf Config, domainID, token string, users []sdk.User, grou return nil } -func update(s sdk.SDK, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { - ctx := context.Background() +func update(ctx context.Context, s sdk.SDK, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { for _, user := range users { user.FirstName = namesgenerator.Generate() user.Metadata = sdk.Metadata{"Update": namesgenerator.Generate()} @@ -545,8 +539,7 @@ func update(s sdk.SDK, domainID, token string, users []sdk.User, groups []sdk.Gr return nil } -func messaging(s sdk.SDK, conf Config, domainID, token string, clients []sdk.Client, channels []sdk.Channel) error { - ctx := context.Background() +func messaging(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, clients []sdk.Client, channels []sdk.Channel) error { for _, c := range clients { for _, channel := range channels { conn := sdk.Connection{ @@ -569,7 +562,7 @@ func messaging(s sdk.SDK, conf Config, domainID, token string, clients []sdk.Cli func(num int64, client sdk.Client, channel sdk.Channel) { g.Go(func() error { msg := fmt.Sprintf(msgFormat, num+1, rand.Int()) - return sendHTTPMessage(s, msg, client, channel.ID) + return sendHTTPMessage(ctx, s, msg, client, channel.ID) }) g.Go(func() error { msg := fmt.Sprintf(msgFormat, num+2, rand.Int()) @@ -592,8 +585,7 @@ func messaging(s sdk.SDK, conf Config, domainID, token string, clients []sdk.Cli return g.Wait() } -func sendHTTPMessage(s sdk.SDK, msg string, client sdk.Client, chanID string) error { - ctx := context.Background() +func sendHTTPMessage(ctx context.Context, s sdk.SDK, msg string, client sdk.Client, chanID string) error { if err := s.SendMessage(ctx, client.DomainID, chanID, msg, client.Credentials.Secret); err != nil { return fmt.Errorf("HTTP failed to send message from client %s to channel %s: %w", client.ID, chanID, err) } diff --git a/tools/provision/cmd/main.go b/tools/provision/cmd/main.go index 5326ee930..632d8555a 100644 --- a/tools/provision/cmd/main.go +++ b/tools/provision/cmd/main.go @@ -19,8 +19,8 @@ func main() { Short: "provision is provisioning tool for SuperMQ", Long: `Tool for provisioning series of SuperMQ channels and clients and connecting them together. Complete documentation is available at https://docs.supermq.abstractmachines.fr`, - Run: func(_ *cobra.Command, _ []string) { - if err := provision.Provision(pconf); err != nil { + Run: func(cmd *cobra.Command, _ []string) { + if err := provision.Provision(cmd.Context(), pconf); err != nil { log.Fatal(err) } }, diff --git a/tools/provision/provision.go b/tools/provision/provision.go index 91bfe9bd4..80d8dd765 100644 --- a/tools/provision/provision.go +++ b/tools/provision/provision.go @@ -56,8 +56,7 @@ type Config struct { } // Provision - function that does actual provisiong. -func Provision(conf Config) error { - ctx := context.Background() +func Provision(ctx context.Context, conf Config) error { const ( rsaBits = 4096 ttl = "2400h"